From 7c7b02ce9103f07655118beef0cedd5c24466109 Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Wed, 24 Jan 2024 12:17:31 -0800 Subject: [PATCH] Add GP_UCB_PE and make it the new default. PiperOrigin-RevId: 601196580 --- README.md | 4 +- docs/guides/user/running_vizier.ipynb | 2 +- docs/guides/user/supported_algorithms.ipynb | 13 +- vizier/__init__.py | 2 +- vizier/_src/algorithms/designers/gp_ucb_pe.py | 825 ++++++++++++++++++ vizier/_src/pyvizier/oss/study_config.py | 4 +- vizier/_src/service/policy_factory.py | 7 +- vizier/_src/service/vizier_client_test.py | 6 +- vizier/algorithms/designers/__init__.py | 1 + vizier/pyvizier/converters/padding.py | 4 +- 10 files changed, 851 insertions(+), 17 deletions(-) create mode 100644 vizier/_src/algorithms/designers/gp_ucb_pe.py diff --git a/README.md b/README.md index d5f0295a7..7c8e55935 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ def evaluate(w: float, x: int, y: float, z: str) -> float: return w**2 - y**2 + x * ord(z) # Algorithm, search space, and metrics. -study_config = vz.StudyConfig(algorithm='GAUSSIAN_PROCESS_BANDIT') +study_config = vz.StudyConfig(algorithm='DEFAULT') study_config.search_space.root.add_float_param('w', 0.0, 5.0) study_config.search_space.root.add_int_param('x', -2, 2) study_config.search_space.root.add_discrete_param('y', [0.3, 7.2]) @@ -46,7 +46,7 @@ study_config.metric_information.append(vz.MetricInformation('metric_name', goal= # Setup client and begin optimization. Vizier Service will be implicitly created. 
study = clients.Study.from_study_config(study_config, owner='my_name', study_id='example') for i in range(10): - suggestions = study.suggest(count=1) + suggestions = study.suggest(count=2) for suggestion in suggestions: params = suggestion.parameters objective = evaluate(params['w'], params['x'], params['y'], params['z']) diff --git a/docs/guides/user/running_vizier.ipynb b/docs/guides/user/running_vizier.ipynb index 4103e4f30..5d4184e01 100644 --- a/docs/guides/user/running_vizier.ipynb +++ b/docs/guides/user/running_vizier.ipynb @@ -108,7 +108,7 @@ "outputs": [], "source": [ "study_config = vz.StudyConfig.from_problem(problem)\n", - "study_config.algorithm = 'GAUSSIAN_PROCESS_BANDIT'" + "study_config.algorithm = 'DEFAULT'" ] }, { diff --git a/docs/guides/user/supported_algorithms.ipynb b/docs/guides/user/supported_algorithms.ipynb index b89db49d0..d709a881e 100644 --- a/docs/guides/user/supported_algorithms.ipynb +++ b/docs/guides/user/supported_algorithms.ipynb @@ -19,12 +19,13 @@ "## Official\n", "The following algorithms can be considered \"official\" and production-quality:\n", "\n", - "1. [**GP-Bandit**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/gp_bandit.py) (`GAUSSIAN_PROCESS_BANDIT`): Flat Search Spaces.\n", - "2. [**Random Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/random.py) (`RANDOM_SEARCH`): Flat Search Spaces.\n", - "3. [**Quasi-Random Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/quasi_random.py) (`QUASI_RANDOM_SEARCH`): Flat Search Spaces.\n", - "4. [**Grid Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/grid.py) (`GRID_SEARCH`): Flat Search Spaces.\n", - "5. [**Shuffled Grid Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/grid.py) (`SHUFFLED_GRID_SEARCH`): Flat Search Spaces.\n", - "6. 
[**Eagle Strategy**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py) (`EAGLE_STRATEGY`): Flat Search Spaces." + "1. [**GP-UCB-PE**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/gp_ucb_pe.py) (`GP_UCB_PE`): Flat Search Spaces.\n", + "2. [**GP-Bandit**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/gp_bandit.py) (`GAUSSIAN_PROCESS_BANDIT`): Flat Search Spaces.\n", + "3. [**Random Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/random.py) (`RANDOM_SEARCH`): Flat Search Spaces.\n", + "4. [**Quasi-Random Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/quasi_random.py) (`QUASI_RANDOM_SEARCH`): Flat Search Spaces.\n", + "5. [**Grid Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/grid.py) (`GRID_SEARCH`): Flat Search Spaces.\n", + "6. [**Shuffled Grid Search**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/grid.py) (`SHUFFLED_GRID_SEARCH`): Flat Search Spaces.\n", + "7. [**Eagle Strategy**](https://github.com/google/vizier/blob/main/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py) (`EAGLE_STRATEGY`): Flat Search Spaces." ] }, { diff --git a/vizier/__init__.py b/vizier/__init__.py index c6ea18cde..7e6fa7289 100644 --- a/vizier/__init__.py +++ b/vizier/__init__.py @@ -23,4 +23,4 @@ sys.path.append(PROTO_ROOT) -__version__ = "0.1.14" +__version__ = "0.1.15" diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe.py b/vizier/_src/algorithms/designers/gp_ucb_pe.py new file mode 100644 index 000000000..d1e1a46c4 --- /dev/null +++ b/vizier/_src/algorithms/designers/gp_ucb_pe.py @@ -0,0 +1,825 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +"""Gaussian Process Bandit with Pure Exploration using a Flax model and a TFP Gaussian Process.""" + +# pylint: disable=logging-fstring-interpolation, g-long-lambda + +import copy +import datetime +import random +from typing import Any, Callable, Mapping, Optional, Sequence, Union + +from absl import logging +import attr +import chex +import equinox as eqx +import jax +from jax import numpy as jnp +import jaxtyping as jt +import numpy as np +from tensorflow_probability.substrates import jax as tfp # pylint: disable=g-importing-member +from vizier import algorithms as vza +from vizier import pyvizier as vz +from vizier._src.algorithms.designers import quasi_random +from vizier._src.algorithms.designers.gp import acquisitions +from vizier._src.algorithms.designers.gp import output_warpers +from vizier._src.algorithms.optimizers import eagle_strategy as es +from vizier._src.algorithms.optimizers import vectorized_base as vb +from vizier._src.jax import stochastic_process_model as sp +from vizier._src.jax import types +from vizier._src.jax.models import tuned_gp_models +from vizier.jax import optimizers +from vizier.pyvizier import converters +from vizier.pyvizier.converters import padding +from vizier.utils import profiler + + +tfd = tfp.distributions + + +class UCBPEConfig(eqx.Module): + """UCB-PE config parameters.""" + + ucb_coefficient: jt.Float[jt.Array, ''] = eqx.field( + default=1.8, converter=jnp.asarray + ) + # A separate ucb coefficient defining the good region to explore. 
+ explore_region_ucb_coefficient: jt.Float[jt.Array, ''] = eqx.field( + default=0.5, converter=jnp.asarray + ) + # The constraint violation penalty is a linear function of the constraint + # violation, whose slope is determined by this coefficient. + cb_violation_penalty_coefficient: jt.Float[jt.Array, ''] = eqx.field( + default=10.0, converter=jnp.asarray + ) + # Probability of using empty pending trials during batched suggestions. + ucb_overwrite_probability: jt.Float[jt.Array, ''] = eqx.field( + default=0.25, converter=jnp.asarray + ) + + def __repr__(self): + return eqx.tree_pformat(self, short_arrays=False) + + +# A dummy loss for ARD when there are no completed trials. +_DUMMY_LOSS = -1.0 + + +def _has_new_completed_trials( + completed_trials: Sequence[vz.Trial], active_trials: Sequence[vz.Trial] +) -> bool: + """Returns True iff there are newer completed trials than active trials. + + Args: + completed_trials: Completed trials. + active_trials: Active trials. + + Returns: + True if `completed_trials` is non-empty and: + - `active_trials` is empty, or + - The latest `completion_time` among `completed_trials` is + later than the latest `creation_time` among `active_trials`. + False: otherwise. + """ + + if not completed_trials: + return False + if not active_trials: + return True + + completed_completion_times = [t.completion_time for t in completed_trials] + active_creation_times = [t.creation_time for t in active_trials] + + if not all(completed_completion_times): + raise ValueError('All completed trials must have completion times.') + if not all(active_creation_times): + raise ValueError('All active trials must have creation times.') + + return max(completed_completion_times) > max(active_creation_times) # pytype:disable=unsupported-operands + + +def _compute_ucb_threshold( + gprm: tfd.Distribution, + is_missing: jt.Bool[jt.Array, ''], + ucb_coefficient: jt.Float[jt.Array, ''], +) -> jax.Array: + """Computes a threshold on UCB values. 
+ + A promising evaluation point has UCB value no less than the threshold + computed here. The threshold is the predicted mean of the feature array + with the maximum UCB value among the points `gprm.index_points`. + + Args: + gprm: A GP regression model for a set of predictive index points. + is_missing: A 1-d boolean array indicating whether the corresponding + predictive index points are missing. + ucb_coefficient: The UCB coefficient. + + Returns: + The predicted mean of the feature array with the maximum UCB among `xs`. + """ + pred_mean = gprm.mean() + ucb_values = jnp.where( + is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev() + ) + return pred_mean[jnp.argmax(ucb_values)] + + +# TODO: Use acquisitions.TrustRegion instead. +def _apply_trust_region( + tr: acquisitions.TrustRegion, xs: types.ModelInput, acq_values: jax.Array +) -> jax.Array: + """Applies the trust region to acquisition function values. + + Args: + tr: Trust region. + xs: Predictive index points. + acq_values: Acquisition function values at predictive index points. + + Returns: + Acquisition function values with trust region applied. + """ + distance = tr.min_linf_distance(xs) + # Due to output normalization, acquisition values can't be as low as -1e12. + # We use a bad value that decreases in the distance to trust region so that + # acquisition optimizer can follow the gradient and escape untrusted regions. + return jnp.where( + (distance < tr.trust_radius) | (tr.trust_radius > 0.5), + acq_values, + -1e12 - distance, + ) + + +def _get_features_shape( + features: types.ModelInput, +) -> types.ContinuousAndCategorical: + """Gets the shapes of continuous/categorical features for logging.""" + return types.ContinuousAndCategorical( + features.continuous.shape, + features.categorical.shape, + ) + + +class UCBScoreFunction(eqx.Module): + """Computes the UCB acquisition value. 
+ + The UCB acquisition value is the sum of the predicted mean based on completed + trials and the predicted standard deviation based on all trials, completed and + pending (scaled by the UCB coefficient). This class follows the + `acquisitions.ScoreFunction` protocol. + + Attributes: + predictive: Predictive model with cached Cholesky conditioned on completed + trials. + predictive_all_features: Predictive model with cached Cholesky conditioned + on completed and pending trials. + ucb_coefficient: The UCB coefficient. + trust_region: Trust region. + """ + + predictive: sp.UniformEnsemblePredictive + predictive_all_features: sp.UniformEnsemblePredictive + ucb_coefficient: jt.Float[jt.Array, ''] + trust_region: Optional[acquisitions.TrustRegion] + + def score( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> jax.Array: + return self.score_with_aux(xs, seed=seed)[0] + + def aux( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> chex.ArrayTree: + return self.score_with_aux(xs, seed=seed)[1] + + def score_with_aux( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> tuple[jax.Array, chex.ArrayTree]: + del seed + gprm = self.predictive.predict(xs) + gprm_all_features = self.predictive_all_features.predict(xs) + mean = gprm.mean() + stddev_from_all = gprm_all_features.stddev() + acq_values = mean + self.ucb_coefficient * stddev_from_all + if self.trust_region is not None: + acq_values = _apply_trust_region(self.trust_region, xs, acq_values) + return acq_values, { + 'mean': mean, + 'stddev': gprm.stddev(), + 'stddev_from_all': stddev_from_all, + } + + +class PEScoreFunction(eqx.Module): + """Computes the Pure-Exploration acquisition value. + + The PE acquisition value is the predicted standard deviation based on + all suggestions, completed and pending, plus a penalty term that grows + linearly in the amount of violation of the constraint + `UCB(xs) >= threshold`. 
This class follows the `acquisitions.ScoreFunction` + protocol. + + Attributes: + predictive: Predictive model with cached Cholesky conditioned on completed + trials. + predictive_all_features: Predictive model with cached Cholesky conditioned + on completed and pending trials. + ucb_coefficient: The UCB coefficient used to compute the threshold. + explore_ucb_coefficient: The UCB coefficient used for computing the UCB + values on `xs`. + penalty_coefficient: Multiplier on the constraint violation penalty. + trust_region: Optional trust region applied to the acquisition values. + + Returns: + The Pure-Exploration acquisition value. + """ + + predictive: sp.UniformEnsemblePredictive + predictive_all_features: sp.UniformEnsemblePredictive + ucb_coefficient: jt.Float[jt.Array, ''] + explore_ucb_coefficient: jt.Float[jt.Array, ''] + penalty_coefficient: jt.Float[jt.Array, ''] + trust_region: Optional[acquisitions.TrustRegion] + + def score( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> jax.Array: + return self.score_with_aux(xs, seed=seed)[0] + + def aux( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> chex.ArrayTree: + return self.score_with_aux(xs, seed=seed)[1] + + def score_with_aux( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> tuple[jax.Array, chex.ArrayTree]: + del seed + features = self.predictive_all_features.predictives.observed_data.features + is_missing = ( + features.continuous.is_missing[0] | features.categorical.is_missing[0] + ) + gprm_threshold = self.predictive.predict(features) + threshold = _compute_ucb_threshold( + gprm_threshold, is_missing, self.ucb_coefficient + ) + gprm = self.predictive.predict(xs) + mean = gprm.mean() + stddev = gprm.stddev() + explore_ucb = mean + stddev * self.explore_ucb_coefficient + + gprm_all = self.predictive_all_features.predict(xs) + stddev_from_all = gprm_all.stddev() + acq_values = stddev_from_all + self.penalty_coefficient * jnp.minimum( + explore_ucb - threshold, + 0.0, + ) + if self.trust_region is 
not None: + acq_values = _apply_trust_region(self.trust_region, xs, acq_values) + return acq_values, { + 'mean': mean, + 'stddev': stddev, + 'stddev_from_all': stddev_from_all, + } + + +def default_ard_optimizer() -> optimizers.Optimizer[types.ParameterDict]: + return optimizers.JaxoptScipyLbfgsB( + optimizers.LbfgsBOptions( + num_line_search_steps=20, + tol=1e-5, + maxiter=500, + ) + ) + + +# TODO: Remove excess use of copy.deepcopy() +@attr.define(auto_attribs=False) +class VizierGPUCBPEBandit(vza.Designer): + """GP_UCB_PE with a flax model. + + Attributes: + problem: Must be a flat study with a single metric. + acquisition_optimizer_factory: Builds the acquisition optimizer. + metadata_ns: Metadata namespace that this designer writes to. + use_trust_region: Uses trust region. + ard_optimizer: An optimizer object, which should return a batch of + hyperparameters to be ensembled. + num_seed_trials: If greater than zero, first trial is the center of the + search space. Afterwards, uses quasirandom until this number of trials are + observed. + rng: JAX PRNG key; if not set, one is seeded randomly. + clear_jax_cache: If True, every `suggest` call clears the Jax cache. 
+ """ + + _problem: vz.ProblemStatement = attr.field(kw_only=False) + _acquisition_optimizer_factory: Union[ + Callable[[Any], vza.GradientFreeOptimizer], vb.VectorizedOptimizerFactory + ] = attr.field( + kw_only=True, + factory=lambda: VizierGPUCBPEBandit.default_acquisition_optimizer_factory, + ) + _metadata_ns: str = attr.field( + default='google_gp_ucb_pe_bandit', kw_only=True + ) + _ensemble_size: Optional[int] = attr.field(default=1, kw_only=True) + _all_completed_trials: list[vz.Trial] = attr.field(factory=list, init=False) + _all_active_trials: Sequence[vz.Trial] = attr.field(factory=list, init=False) + _ard_optimizer: optimizers.Optimizer[types.ParameterDict] = attr.field( + factory=default_ard_optimizer, + kw_only=True, + ) + _ard_random_restarts: int = attr.field(default=4, kw_only=True) + _use_trust_region: bool = attr.field(default=True, kw_only=True) + _num_seed_trials: int = attr.field(default=1, kw_only=True) + _config: UCBPEConfig = attr.field( + factory=UCBPEConfig, + kw_only=True, + ) + _rng: jax.Array = attr.field( + factory=lambda: jax.random.PRNGKey(random.getrandbits(32)), kw_only=True + ) + _clear_jax_cache: bool = attr.field(default=False, kw_only=True) + # Whether to pad all inputs, and what type of schedule to use. This is to + # ensure fewer JIT compilation passes. (Default implies no padding.) + # TODO: Check padding does not affect designer behavior. 
+ _padding_schedule: padding.PaddingSchedule = attr.field( + factory=padding.PaddingSchedule, kw_only=True + ) + + default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory( + strategy_factory=es.VectorizedEagleStrategyFactory( + eagle_config=es.EagleStrategyConfig( + visibility=3.8531520841659943, + gravity=2.9404850423149256, + negative_gravity=0.019006238408962216, + perturbation=0.27435618161962827, + categorical_perturbation_factor=6.97209495578964, + pure_categorical_perturbation_factor=26.7913174702563, + prob_same_category_without_perturbation=0.9757052640433359, + perturbation_lower_bound=0.0011127128824487108, + penalize_factor=0.6922247141686396, + pool_size_exponent=1.9066662747156424, + mutate_normalization_type=es.MutateNormalizationType.MEAN, + normalization_scale=2.0048423159124265, + prior_trials_pool_pct=0.42033336568130314, + ) + ), + max_evaluations=75000, + suggestion_batch_size=25, + ) + + def __attrs_post_init__(self): + # Extra validations + if self._problem.search_space.is_conditional: + raise ValueError(f'{type(self)} does not support conditional search.') + elif len(self._problem.metric_information) != 1: + raise ValueError(f'{type(self)} works with exactly one metric.') + + # Extra initializations. + # Discrete parameters are continuified to account for their actual values. 
+ self._converter = converters.TrialToModelInputConverter.from_problem( + self._problem, + scale=True, + max_discrete_indices=0, + flip_sign_for_minimization_metrics=True, + padding_schedule=self._padding_schedule, + ) + qrs_seed, self._rng = jax.random.split(self._rng) + self._quasi_random_sampler = quasi_random.QuasiRandomDesigner( + self._problem.search_space, + seed=int(jax.random.randint(qrs_seed, [], 0, 2**16)), + ) + + def update( + self, completed: vza.CompletedTrials, all_active: vza.ActiveTrials + ) -> None: + self._all_completed_trials.extend(copy.deepcopy(completed.trials)) + self._all_active_trials = copy.deepcopy(all_active.trials) + + @property + def _metric_info(self) -> vz.MetricInformation: + return self._problem.metric_information.item() + + def _generate_seed_trials(self, count: int) -> Sequence[vz.TrialSuggestion]: + """Generate seed trials. + + The first seed trial is chosen as the search space center, the rest of the + seed trials are chosen quasi-randomly. + + Arguments: + count: The number of seed trials. + + Returns: + The seed trials. + """ + seed_suggestions = [] + if (not self._all_completed_trials) and (not self._all_active_trials): + features = self._converter.to_features([]) # to extract shape. + # NOTE: The code below assumes that a scaled value of 0.5 corresponds + # to the center of the feasible range. This is true, but only by accident; + # ideally, we should get the center from the converters. 
+ parameters = self._converter.to_parameters( + types.ModelInput( + continuous=self._padding_schedule.pad_features( + 0.5 * np.ones([1, features.continuous.shape[1]]) + ), + categorical=self._padding_schedule.pad_features( + np.zeros( + [1, features.categorical.shape[1]], dtype=types.INT_DTYPE + ) + ), + ) + )[0] + seed_suggestions.append( + vz.TrialSuggestion( + parameters, metadata=vz.Metadata({'seeded': 'center'}) + ) + ) + if (remaining_counts := count - len(seed_suggestions)) > 0: + seed_suggestions.extend( + self._quasi_random_sampler.suggest(remaining_counts) + ) + return seed_suggestions + + @profiler.record_runtime( + name_prefix='VizierGPUCBPEBandit', + name='build_gp_model_and_optimize_parameters', + ) + def _build_gp_model_and_optimize_parameters( + self, data: types.ModelData, rng: jax.Array + ) -> sp.StochasticProcessWithCoroutine: + """Builds a GP model and optimizes parameters. + + Args: + data: Observed features and labels. + rng: A key for random number generation. + + Returns: + A tuple of GP model and its parameters optimized over `data.features` and + `data.labels`. If `data.features` is empty, the returned parameters are + initial values picked by the GP model. + """ + # TODO: Update to `VizierLinearGaussianProcess`. + coroutine = tuned_gp_models.VizierGaussianProcess.build_model( + data.features + ).coroutine + model = sp.CoroutineWithData(coroutine, data) + + if (data.features.continuous.padded_array.shape[0] == 0) and ( + data.features.categorical.padded_array.shape[0] == 0 + ): + # This happens when `suggest` is called after the seed trials are + # generated without any completed trials. In this case, the designer + # uses the PE acquisition, but still needs a GP to do that. By using a + # dummy loss here, the ARD optimizer is expected to return the initial + # values it uses for the parameters. 
+ ard_loss_with_aux = lambda _: (_DUMMY_LOSS, dict()) + else: + ard_loss_with_aux = model.loss_with_aux + + logging.info( + 'Optimizing the loss function on features with shape ' + f'{_get_features_shape(data.features)} and labels with shape ' + f'{data.labels.shape}...' + ) + constraints = sp.get_constraints(model) + rng, init_rng = jax.random.split(rng, 2) + random_init_params = eqx.filter_jit(eqx.filter_vmap(model.setup))( + jax.random.split(init_rng, self._ard_random_restarts) + ) + fixed_init_params = { + 'signal_variance': jnp.array([0.039]), + 'observation_noise_variance': jnp.array([0.0039]), + 'continuous_length_scale_squared': jnp.array( + [[1.0] * data.features.continuous.padded_array.shape[-1]] + ), + 'categorical_length_scale_squared': jnp.array( + [[1.0] * data.features.categorical.padded_array.shape[-1]] + ), + } + best_n = self._ensemble_size or 1 + optimal_params, metrics = self._ard_optimizer( + init_params=jax.tree_map( + lambda x, y: jnp.concatenate([x, y]), + fixed_init_params, + random_init_params, + ), + loss_fn=ard_loss_with_aux, + rng=rng, + constraints=constraints, + best_n=best_n, + ) + # The `"loss"` field of the `metrics` output of ARD optimizers contains an + # array of losses of shape `[num_steps, num_random_restarts]` (or + # `[1, num_random_restarts]` if only the final loss is recorded). + if jnp.any(metrics['loss'][-1, :].argsort()[:best_n] == 0): + logging.info( + 'Parameters found by fixed initialization are among the best' + f' {best_n} parameters.' + ) + else: + logging.info( + f'The best {best_n} parameters were all found by random' + ' initialization.' + ) + + logging.info('Optimal parameters: %s', optimal_params) + return sp.StochasticProcessWithCoroutine( + coroutine=coroutine, params=optimal_params + ) + + def get_score_fn_on_trials( + self, score_fn: Callable[[types.ModelInput], jax.Array] + ) -> Callable[[Sequence[vz.Trial]], Mapping[str, jax.Array]]: + """Builds a callable that evaluates the score function on trials. 
+ + Args: + score_fn: Score function that takes arrays as input. + + Returns: + Score function that takes trials as input. + """ + + def acquisition(trials: Sequence[vz.Trial]) -> Mapping[str, jax.Array]: + jax_acquisitions = eqx.filter_jit(score_fn)( + self._converter.to_features(trials) + ) + return {'acquisition': jax_acquisitions} + + return acquisition + + @profiler.record_runtime + def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData: + """Convert trials to scaled features and warped labels.""" + # TrialToArrayConverter returns floating arrays. + data = self._converter.to_xy(trials) + logging.info( + 'Transforming the labels of shape %s. Features has shape: %s', + data.labels.shape, + _get_features_shape(data.features), + ) + warped_labels = output_warpers.create_default_warper().warp( + np.array(data.labels.unpad()) + ) + labels = types.PaddedArray.from_array( + warped_labels, + data.labels.padded_array.shape, + fill_value=data.labels.fill_value, + ) + logging.info('Transformed the labels. Now has shape: %s', labels.shape) + return types.ModelData(features=data.features, labels=labels) + + @profiler.record_runtime( + name_prefix='VizierGPUCBPEBandit', name='get_predictive_all_features' + ) + def _get_predictive_all_features( + self, + pending_features: types.ModelInput, + data: types.ModelData, + model: sp.StochasticProcessWithCoroutine, + ) -> sp.UniformEnsemblePredictive: + """Builds the predictive model conditioned on observed and pending features. + + Args: + pending_features: Pending features. + data: Features/labels for completed trials. + model: The GP model. + + Returns: + Predictive model with cached Cholesky conditioned on observed and pending + features. + """ + # TODO: Use `PaddedArray.concatenate` when implemented. 
+ all_features_continuous = jnp.concatenate( + [ + data.features.continuous.unpad(), + pending_features.continuous.unpad(), + ], + axis=0, + ) + all_features_categorical = jnp.concatenate( + [ + data.features.categorical.unpad(), + pending_features.categorical.unpad(), + ], + axis=0, + ) + all_features = types.ModelInput( + continuous=self._padding_schedule.pad_features(all_features_continuous), + categorical=self._padding_schedule.pad_features( + all_features_categorical + ), + ) + # Pending features are only used to predict standard deviation, so their + # labels do not matter, and we simply set them to 0. + dummy_labels = jnp.zeros( + shape=(pending_features.continuous.unpad().shape[0], 1), + dtype=data.labels.padded_array.dtype, + ) + all_labels = jnp.concatenate([data.labels.unpad(), dummy_labels], axis=0) + all_labels = self._padding_schedule.pad_labels(all_labels) + all_data = types.ModelData(features=all_features, labels=all_labels) + return sp.UniformEnsemblePredictive( + predictives=eqx.filter_jit(model.precompute_predictive)(all_data) + ) + + @profiler.record_runtime(name_prefix='VizierGPUCBPEBandit', name='suggest') + def suggest( + self, count: Optional[int] = None + ) -> Sequence[vz.TrialSuggestion]: + count = count or 1 + if ( + len(self._all_completed_trials) + len(self._all_active_trials) + < self._num_seed_trials + ): + return self._generate_seed_trials(count) + + if self._clear_jax_cache: + jax.clear_caches() + + self._rng, rng = jax.random.split(self._rng, 2) + begin = datetime.datetime.now() + data = self._trials_to_data(self._all_completed_trials) + model = self._build_gp_model_and_optimize_parameters(data, rng) + predictive = sp.UniformEnsemblePredictive( + predictives=eqx.filter_jit(model.precompute_predictive)(data) + ) + + # Optimize acquisition. 
+ active_trial_features = self._converter.to_features(self._all_active_trials) + + tr_features = types.ModelInput( + continuous=self._padding_schedule.pad_features( + jnp.concatenate( + [ + data.features.continuous.unpad(), + active_trial_features.continuous.unpad(), + ], + axis=0, + ) + ), + categorical=self._padding_schedule.pad_features( + jnp.concatenate( + [ + data.features.categorical.unpad(), + active_trial_features.categorical.unpad(), + ], + axis=0, + ), + ), + ) + tr = acquisitions.TrustRegion(trusted=tr_features) + + acquisition_problem = copy.deepcopy(self._problem) + acquisition_problem.metric_information = [ + vz.MetricInformation( + name='acquisition', goal=vz.ObjectiveMetricGoal.MAXIMIZE + ) + ] + logging.info('Optimizing acquisition...') + + # TODO: Feed the eagle strategy with completed trials. + # TODO: Change budget based on requested suggestion count. + suggestions = [] + active_trials = list(self._all_active_trials) + for _ in range(count): + self._rng, rng = jax.random.split(self._rng, 2) + ucb_overwrite = jax.random.bernoulli( + key=rng, p=self._config.ucb_overwrite_probability + ) + # Optimize the UCB acquisition when there are trials completed after all + # active trials were created, or when `ucb_overwrite` is true. The + # `ucb_overwrite_probability` config parameter should be set to a small + # positive value so that the UCB acquisition function is optimized for + # more than one but not too many suggestions in a batch suggestion + # request. This helps compensate for sub-optimality of the acquisition + # function optimizer, without compromising the diversity of the + # suggestions in the feature space. + use_ucb = _has_new_completed_trials( + completed_trials=self._all_completed_trials, + active_trials=active_trials, + ) or (ucb_overwrite and self._all_completed_trials) + # TODO: Feed the eagle strategy with completed trials. + # TODO: Change budget based on requested suggestion count. 
+ acquisition_optimizer = self._acquisition_optimizer_factory( + self._converter + ) + + if active_trials: + pending_features = self._converter.to_features(active_trials) + predictive_all_features = self._get_predictive_all_features( + pending_features, data, model + ) + else: + predictive_all_features = predictive + + # When `use_ucb` is true, the acquisition function computes the UCB + # values. Otherwise, it computes the Pure-Exploration acquisition values. + if use_ucb: + scoring_fn = UCBScoreFunction( + predictive, + predictive_all_features, + ucb_coefficient=self._config.ucb_coefficient, + trust_region=tr if self._use_trust_region else None, + ) + else: + scoring_fn = PEScoreFunction( + predictive, + predictive_all_features, + penalty_coefficient=self._config.cb_violation_penalty_coefficient, + ucb_coefficient=self._config.ucb_coefficient, + explore_ucb_coefficient=self._config.explore_region_ucb_coefficient, + trust_region=tr if self._use_trust_region else None, + ) + + if isinstance(acquisition_optimizer, vb.VectorizedOptimizer): + acq_rng, self._rng = jax.random.split(self._rng) + prior_features = None + if self._all_completed_trials: + prior_features = vb.trials_to_sorted_array( + self._all_completed_trials, self._converter + ) + with profiler.timeit('acquisition_optimizer', also_log=True): + best_candidates = eqx.filter_jit(acquisition_optimizer)( + scoring_fn.score, + prior_features=prior_features, + count=1, + seed=acq_rng, + score_with_aux_fn=scoring_fn.score_with_aux, + ) + jax.block_until_ready(best_candidates) + with profiler.timeit('best_candidates_to_trials', also_log=True): + best_candidate = vb.best_candidates_to_trials( + best_candidates, self._converter + )[0] + elif isinstance(acquisition_optimizer, vza.GradientFreeOptimizer): + # Seed the optimizer with previous trials. 
+ acquisition = self.get_score_fn_on_trials(scoring_fn.score) + best_candidate = acquisition_optimizer.optimize( + acquisition, + acquisition_problem, + count=1, + seed_candidates=copy.deepcopy(self._all_completed_trials), + )[0] + else: + raise ValueError( + f'Unrecognized acquisition_optimizer: {type(acquisition_optimizer)}' + ) + + # Make predictions (in the warped space). + logging.info('Converting the optimization result into suggestion...') + optimal_features = self._converter.to_features([best_candidate]) # [1, D] + aux = eqx.filter_jit(scoring_fn.aux)(optimal_features) + predict_mean = aux['mean'] # [1,] + predict_stddev = aux['stddev'] # [1,] + predict_stddev_from_all = aux['stddev_from_all'] # [1,] + acquisition = best_candidate.final_measurement_or_die.metrics.get_value( + 'acquisition', float('nan') + ) + logging.info( + 'Created predictions for the best candidates which were converted to' + f' an array of shape: {_get_features_shape(optimal_features)}. mean' + f' has shape {predict_mean.shape}. stddev has shape' + f' {predict_stddev.shape}.stddev_from_all has shape' + f' {predict_stddev_from_all.shape}. acquisition value of' + f' best_candidate: {acquisition}, use_ucb: {use_ucb}' + ) + + # Create suggestions, injecting the predictions as metadata for + # debugging needs. 
+ metadata = best_candidate.metadata.ns(self._metadata_ns) + metadata.ns('prediction_in_warped_y_space').update({ + 'mean': f'{predict_mean[0]}', + 'stddev': f'{predict_stddev[0]}', + 'stddev_from_all': f'{predict_stddev_from_all[0]}', + 'acquisition': f'{acquisition}', + 'use_ucb': f'{use_ucb}', + 'trust_radius': f'{tr.trust_radius}', + 'params': f'{model.params}', + }) + metadata.ns('timing').update( + {'time': f'{datetime.datetime.now() - begin}'} + ) + suggestions.append( + vz.TrialSuggestion( + best_candidate.parameters, metadata=best_candidate.metadata + ) + ) + active_trials.append(suggestions[-1].to_trial()) + + return suggestions diff --git a/vizier/_src/pyvizier/oss/study_config.py b/vizier/_src/pyvizier/oss/study_config.py index c2edfd0f8..1719d5a49 100644 --- a/vizier/_src/pyvizier/oss/study_config.py +++ b/vizier/_src/pyvizier/oss/study_config.py @@ -62,8 +62,10 @@ class Algorithm(enum.Enum): """Valid Values for StudyConfig.Algorithm.""" - # Let Vizier choose algorithm. Currently defaults to GAUSSIAN_PROCESS_BANDIT. + # Let Vizier choose algorithm. Currently defaults to GP_UCB_PE. ALGORITHM_UNSPECIFIED = 'ALGORITHM_UNSPECIFIED' + # Gaussian Process UCB with Pure Exploration. + GP_UCB_PE = 'GP_UCB_PE' # Gaussian Process Bandit. GAUSSIAN_PROCESS_BANDIT = 'GAUSSIAN_PROCESS_BANDIT' # Grid search within the feasible space. 
diff --git a/vizier/_src/service/policy_factory.py b/vizier/_src/service/policy_factory.py index 6346a9a37..9d17a90dd 100644 --- a/vizier/_src/service/policy_factory.py +++ b/vizier/_src/service/policy_factory.py @@ -15,6 +15,7 @@ from __future__ import annotations """Service-related policy factories.""" + # pylint:disable=g-import-not-at-top import functools import time @@ -39,8 +40,12 @@ def __call__( if algorithm in ( 'DEFAULT', 'ALGORITHM_UNSPECIFIED', - 'GAUSSIAN_PROCESS_BANDIT', + 'GP_UCB_PE', ): + from vizier._src.algorithms.designers import gp_ucb_pe + + return dp.DesignerPolicy(policy_supporter, gp_ucb_pe.VizierGPUCBPEBandit) + elif algorithm == 'GAUSSIAN_PROCESS_BANDIT': from vizier._src.algorithms.designers import gp_bandit return dp.DesignerPolicy( diff --git a/vizier/_src/service/vizier_client_test.py b/vizier/_src/service/vizier_client_test.py index a0d9c2f24..6cb074249 100644 --- a/vizier/_src/service/vizier_client_test.py +++ b/vizier/_src/service/vizier_client_test.py @@ -215,15 +215,15 @@ def test_get_suggestions(self): # Only test algorithms which don't depend on external libraries (except for # numpy). 
@parameterized.parameters( + dict(algorithm='DEFAULT'), + dict(algorithm=vz.Algorithm.ALGORITHM_UNSPECIFIED), dict(algorithm=vz.Algorithm.RANDOM_SEARCH), dict(algorithm=vz.Algorithm.QUASI_RANDOM_SEARCH), dict(algorithm=vz.Algorithm.GRID_SEARCH), dict(algorithm=vz.Algorithm.NSGA2, multiobj=True), - dict(algorithm=vz.Algorithm.ALGORITHM_UNSPECIFIED, multiobj=True), - dict(algorithm='DEFAULT', multiobj=True), - dict(algorithm='DEFAULT', multiobj=False), dict(algorithm=vz.Algorithm.GAUSSIAN_PROCESS_BANDIT, multiobj=True), dict(algorithm=vz.Algorithm.GAUSSIAN_PROCESS_BANDIT, multiobj=False), + dict(algorithm=vz.Algorithm.GP_UCB_PE), ) def test_e2e_tuning( self, diff --git a/vizier/algorithms/designers/__init__.py b/vizier/algorithms/designers/__init__.py index 192bf4c7a..ec1082064 100644 --- a/vizier/algorithms/designers/__init__.py +++ b/vizier/algorithms/designers/__init__.py @@ -21,6 +21,7 @@ from vizier._src.algorithms.designers.eagle_strategy.eagle_strategy import EagleStrategyDesigner from vizier._src.algorithms.designers.emukit import EmukitDesigner from vizier._src.algorithms.designers.gp_bandit import VizierGPBandit +from vizier._src.algorithms.designers.gp_ucb_pe import VizierGPUCBPEBandit from vizier._src.algorithms.designers.grid import GridSearchDesigner from vizier._src.algorithms.designers.harmonica import HarmonicaDesigner from vizier._src.algorithms.designers.quasi_random import QuasiRandomDesigner diff --git a/vizier/pyvizier/converters/padding.py b/vizier/pyvizier/converters/padding.py index 91ab5196d..3964b1e32 100644 --- a/vizier/pyvizier/converters/padding.py +++ b/vizier/pyvizier/converters/padding.py @@ -81,7 +81,7 @@ def _pad_trailing_dims( array, padded_shape, fill_value=fill_value ) - def pad_features(self, features: np.ndarray) -> types.PaddedArray: + def pad_features(self, features: types.Array) -> types.PaddedArray: """Pads features in to a `PaddedArray`.""" return self._pad_trailing_dims( features, [self._num_trials, self._num_features] @@ 
-89,7 +89,7 @@ def pad_features(self, features: np.ndarray) -> types.PaddedArray: def pad_labels( self, - labels: np.ndarray, + labels: types.Array, ) -> types.PaddedArray: """Pads labels in to a `PaddedArray`.""" return self._pad_trailing_dims(