Skip to content

Commit

Permalink
Beautify GP algs a bit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609909172
  • Loading branch information
xingyousong authored and copybara-github committed Feb 24, 2024
1 parent 4f8cbb4 commit 1abf2eb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 87 deletions.
100 changes: 39 additions & 61 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
A Python implementation of Google Vizier's GP-Bandit algorithm.
"""

# pylint: disable=logging-fstring-interpolation, g-long-lambda

import copy
Expand All @@ -35,7 +36,7 @@
from vizier import pyvizier as vz
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.designers import scalarization
from vizier._src.algorithms.designers.gp import acquisitions
from vizier._src.algorithms.designers.gp import acquisitions as acq_lib
from vizier._src.algorithms.designers.gp import gp_models
from vizier._src.algorithms.designers.gp import output_warpers
from vizier._src.algorithms.optimizers import eagle_strategy as es
Expand All @@ -56,8 +57,8 @@
suggestion_batch_size=25,
)

default_scoring_function_factory = (
acquisitions.bayesian_scoring_function_factory(lambda _: acquisitions.UCB())
default_scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
lambda _: acq_lib.UCB()
)


Expand Down Expand Up @@ -118,7 +119,7 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
)
_num_seed_trials: int = attr.field(default=1, kw_only=True)
_linear_coef: float = attr.field(default=0.0, kw_only=True)
_scoring_function_factory: acquisitions.ScoringFunctionFactory = attr.field(
_scoring_function_factory: acq_lib.ScoringFunctionFactory = attr.field(
factory=lambda: default_scoring_function_factory,
kw_only=True,
)
Expand Down Expand Up @@ -183,7 +184,7 @@ def __attrs_post_init__(self):
self._acquisition_optimizer = self._acquisition_optimizer_factory(
self._converter
)
acquisition_problem = copy.deepcopy(self._problem)
self._acquisition_problem = copy.deepcopy(self._problem)
empty_data = types.ModelData(
features=self._converter.to_features([]),
labels=types.PaddedArray.as_padded(
Expand All @@ -201,34 +202,25 @@ def __attrs_post_init__(self):
empty_data, predictive, self._use_trust_region
)
if (
isinstance(scoring_fn, acquisitions.MaxValueEntropySearch)
isinstance(scoring_fn, acq_lib.MaxValueEntropySearch)
and self._ensemble_size > 1
):
raise ValueError(
'MaxValueEntropySearch is not supported with ensemble '
'size greater than one.'
)

acquisition_function = getattr(scoring_fn, 'acquisition_fn', None)
if isinstance(acquisition_function, acquisitions.MultiAcquisitionFunction):
acquisition_config = vz.MetricsConfig()
self._acquisition_problem.metric_information = vz.MetricsConfig()
if isinstance(acquisition_function, acq_lib.MultiAcquisitionFunction):
for k in acquisition_function.acquisition_fns.keys():
acquisition_config.append(
vz.MetricInformation(
name=k,
goal=vz.ObjectiveMetricGoal.MAXIMIZE,
)
)
metric = vz.MetricInformation(k, goal=vz.ObjectiveMetricGoal.MAXIMIZE)
self._acquisition_problem.metric_information.append(metric)
else:
acquisition_config = vz.MetricsConfig(
metrics=[
vz.MetricInformation(
name='acquisition', goal=vz.ObjectiveMetricGoal.MAXIMIZE
)
]
metric = vz.MetricInformation(
'acquisition', goal=vz.ObjectiveMetricGoal.MAXIMIZE
)

acquisition_problem.metric_information = acquisition_config
self._acquisition_problem = acquisition_problem
self._acquisition_problem.metric_information.append(metric)

def update(
self, completed: vza.CompletedTrials, all_active: vza.ActiveTrials
Expand Down Expand Up @@ -295,23 +287,18 @@ def _generate_seed_trials(self, count: int) -> Sequence[vz.TrialSuggestion]:
# NOTE: The code below assumes that a scaled value of 0.5 corresponds
# to the center of the feasible range. This is true, but only by
# accident; ideally, we should get the center from the converters.
parameters = self._converter.to_parameters(
types.ModelInput(
continuous=self._padding_schedule.pad_features(
0.5 * np.ones([1, features.continuous.shape[1]])
),
categorical=self._padding_schedule.pad_features(
np.zeros(
[1, features.categorical.shape[1]], dtype=types.INT_DTYPE
)
),
)
)[0]
seed_suggestions.append(
vz.TrialSuggestion(
parameters, metadata=vz.Metadata({'seeded': 'center'})
)
continuous = self._padding_schedule.pad_features(
0.5 * np.ones([1, features.continuous.shape[1]])
)
categorical = self._padding_schedule.pad_features(
np.zeros([1, features.categorical.shape[1]], dtype=types.INT_DTYPE)
)
model_input = types.ModelInput(continuous, categorical)
parameters = self._converter.to_parameters(model_input)[0]
suggestion = vz.TrialSuggestion(
parameters, metadata=vz.Metadata({'seeded': 'center'})
)
seed_suggestions.append(suggestion)
with profiler.timeit('quasi_random_sampler_seed_trials'):
if (remaining_counts := count - len(seed_suggestions)) > 0:
seed_suggestions.extend(
Expand Down Expand Up @@ -435,7 +422,7 @@ def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
@_experimental_override_allowed
@profiler.record_runtime
def _optimize_acquisition(
self, scoring_fn: acquisitions.BayesianScoringFunction, count: int
self, scoring_fn: acq_lib.BayesianScoringFunction, count: int
) -> list[vz.Trial]:
jax.monitoring.record_event(
'/vizier/jax/gp_bandit/optimize_acquisition/called'
Expand Down Expand Up @@ -464,17 +451,15 @@ def _optimize_acquisition(
n_parallel=n_parallel,
)

optimal_features = best_candidates.features
best_candidates = dataclasses.replace(
best_candidates, features=optimal_features
best_candidates, features=best_candidates.features
)

# Convert best_candidates (in scaled space) into suggestions (in unscaled
# space); also append debug information like model predictions.
# space); also append debug information like model predictions. Output shape
# [N, D].
logging.info('Converting the optimization result into suggestions...')
return vb.best_candidates_to_trials(
best_candidates, self._converter
) # [N, D]
return vb.best_candidates_to_trials(best_candidates, self._converter)

@profiler.record_runtime
def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
Expand Down Expand Up @@ -541,7 +526,7 @@ def sample(
continuous=xs.continuous.replace_fill_value(0.0),
categorical=xs.categorical.replace_fill_value(0),
)
samples = eqx.filter_jit(acquisitions.sample_from_predictive)(
samples = eqx.filter_jit(acq_lib.sample_from_predictive)(
gp, xs, num_samples, key=rng
) # (num_samples, num_trials)
# Scope the samples to non-padded only (there's a single padded dimension).
Expand Down Expand Up @@ -596,28 +581,21 @@ def from_problem(
if problem.is_single_objective:
return cls(problem, rng=rng, linear_coef=1.0)
else:
num_objectives = len(
problem.metric_information.of_type(vz.MetricType.OBJECTIVE)
)
random_weights = np.abs(np.random.normal(size=num_objectives))
objectives = problem.metric_information.of_type(vz.MetricType.OBJECTIVE)
random_weights = np.abs(np.random.normal(size=len(objectives)))

def _scalarized_ucb(
data: types.ModelData,
) -> acquisitions.AcquisitionFunction:
def _scalarized_ucb(data: types.ModelData) -> acq_lib.AcquisitionFunction:
del data
ucb = acquisitions.UCB()
scalarizer = scalarization.HyperVolumeScalarization(
weights=random_weights
)
return acquisitions.ScalarizedAcquisition(ucb, scalarizer)
scalarizer = scalarization.HyperVolumeScalarization(random_weights)
return acq_lib.ScalarizedAcquisition(acq_lib.UCB(), scalarizer)

scoring_fn_factory = acquisitions.bayesian_scoring_function_factory(
scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
_scalarized_ucb
)
return cls(
problem,
linear_coef=1.0,
scoring_function_factory=scoring_fn_factory,
scoring_function_factory=scoring_function_factory,
scoring_function_is_parallel=True,
use_trust_region=False,
)
42 changes: 16 additions & 26 deletions vizier/_src/algorithms/designers/gp_ucb_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,27 +433,21 @@ def _generate_seed_trials(self, count: int) -> Sequence[vz.TrialSuggestion]:
# NOTE: The code below assumes that a scaled value of 0.5 corresponds
# to the center of the feasible range. This is true, but only by accident;
# ideally, we should get the center from the converters.
parameters = self._converter.to_parameters(
types.ModelInput(
continuous=self._padding_schedule.pad_features(
0.5 * np.ones([1, features.continuous.shape[1]])
),
categorical=self._padding_schedule.pad_features(
np.zeros(
[1, features.categorical.shape[1]], dtype=types.INT_DTYPE
)
),
)
)[0]
seed_suggestions.append(
vz.TrialSuggestion(
parameters, metadata=vz.Metadata({'seeded': 'center'})
)
continuous = self._padding_schedule.pad_features(
0.5 * np.ones([1, features.continuous.shape[1]])
)
if (remaining_counts := count - len(seed_suggestions)) > 0:
seed_suggestions.extend(
self._quasi_random_sampler.suggest(remaining_counts)
categorical = self._padding_schedule.pad_features(
np.zeros([1, features.categorical.shape[1]], dtype=types.INT_DTYPE)
)
model_input = types.ModelInput(continuous, categorical)
parameters = self._converter.to_parameters(model_input)[0]
suggestion = vz.TrialSuggestion(
parameters, metadata=vz.Metadata({'seeded': 'center'})
)
seed_suggestions.append(suggestion)
if (remaining_counts := count - len(seed_suggestions)) > 0:
quasi_suggestions = self._quasi_random_sampler.suggest(remaining_counts)
seed_suggestions.extend(quasi_suggestions)
return seed_suggestions

@profiler.record_runtime(
Expand Down Expand Up @@ -539,9 +533,7 @@ def _build_gp_model_and_optimize_parameters(
)

logging.info('Optimal parameters: %s', optimal_params)
return sp.StochasticProcessWithCoroutine(
coroutine=coroutine, params=optimal_params
)
return sp.StochasticProcessWithCoroutine(coroutine, optimal_params)

def get_score_fn_on_trials(
self, score_fn: Callable[[types.ModelInput], jax.Array]
Expand Down Expand Up @@ -643,10 +635,8 @@ def suggest(
self, count: Optional[int] = None
) -> Sequence[vz.TrialSuggestion]:
count = count or 1
if (
len(self._all_completed_trials) + len(self._all_active_trials)
< self._num_seed_trials
):
num_total = len(self._all_completed_trials) + len(self._all_active_trials)
if num_total < self._num_seed_trials:
return self._generate_seed_trials(count)

if self._clear_jax_cache:
Expand Down

0 comments on commit 1abf2eb

Please sign in to comment.