Adds sample and predict methods to GP-UCB-PE designer.
PiperOrigin-RevId: 677878382
vizier-team authored and copybara-github committed Sep 23, 2024
1 parent eaac694 commit 1438657
Showing 3 changed files with 130 additions and 16 deletions.
15 changes: 4 additions & 11 deletions vizier/_src/algorithms/designers/gp_bandit.py
@@ -580,18 +580,11 @@ def sample(
samples = samples[
:, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0])
]
unwarped_samples = None
# TODO: vectorize output warping.
for i in range(samples.shape[0]):
unwarp_samples_ = self._output_warper.unwarp(
samples[i][..., np.newaxis]
).reshape(-1)
if unwarped_samples is not None:
unwarped_samples = np.vstack([unwarp_samples_, unwarped_samples])
else:
unwarped_samples = unwarp_samples_

return unwarped_samples # pytype: disable=bad-return-type
return np.vstack([
self._output_warper.unwarp(samples[i][..., np.newaxis]).reshape(-1)
for i in range(samples.shape[0])
])

@profiler.record_runtime
def predict(
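The gp_bandit.py change above replaces the accumulate-and-vstack loop with a single np.vstack over a list comprehension, unwarping each sample row independently. A self-contained numpy sketch of that pattern, using a toy expm1 stand-in for the output warper's unwarp rather than the real OutputWarper:

import numpy as np


def toy_unwarp(column):
  # Hypothetical stand-in for OutputWarper.unwarp: takes a (num_trials, 1)
  # column and inverts a log1p warp; the real warper is fit to observed labels.
  return np.expm1(column)


samples = np.log1p(np.arange(12.0)).reshape(3, 4)  # (num_samples, num_trials)
unwarped = np.vstack([
    toy_unwarp(samples[i][..., np.newaxis]).reshape(-1)
    for i in range(samples.shape[0])
])
assert unwarped.shape == samples.shape  # Each row is unwarped independently.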
84 changes: 81 additions & 3 deletions vizier/_src/algorithms/designers/gp_ucb_pe.py
@@ -546,6 +546,7 @@ def __attrs_post_init__(self):
self._problem.search_space,
seed=int(jax.random.randint(qrs_seed, [], 0, 2**16)),
)
self._output_warper = None

def update(
self, completed: vza.CompletedTrials, all_active: vza.ActiveTrials
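The `self._output_warper = None` line added above sets up a lazy-initialization pattern: the warper is only constructed once labels are seen (in _trials_to_data, further down in this diff), and sample() guards against the unset case. A minimal generic sketch of that pattern with hypothetical names (_ToyWarper, ToyDesigner), not the Vizier classes:

class _ToyWarper:
  """Hypothetical stand-in for an output warper."""

  def unwarp(self, value):
    return value  # Identity, for illustration only.


class ToyDesigner:

  def __init__(self):
    self._output_warper = None  # Deferred until data is available.

  def _labels_to_data(self, labels):
    self._output_warper = _ToyWarper()  # Single point of construction.
    return list(labels)

  def sample(self, labels):
    data = self._labels_to_data(labels)
    if self._output_warper is None:
      raise TypeError('Output warper is expected to be set, but found to be None.')
    return [self._output_warper.unwarp(x) for x in data]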
@@ -706,9 +707,8 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
data.labels.shape,
_get_features_shape(data.features),
)
warped_labels = output_warpers.create_default_warper().warp(
np.array(data.labels.unpad())
)
self._output_warper = output_warpers.create_default_warper()
warped_labels = self._output_warper.warp(np.array(data.labels.unpad()))
labels = types.PaddedArray.from_array(
warped_labels,
data.labels.padded_array.shape,
@@ -1013,6 +1013,84 @@ def _suggest_batch_with_exploration(
)
return suggestions

  @profiler.record_runtime
  def sample(
      self,
      trials: Sequence[vz.TrialSuggestion],
      rng: Optional[jax.Array] = None,
      num_samples: int = 1000,
  ) -> types.Array:
    """Returns unwarped samples from the model for any given trials.

    Arguments:
      trials: The trials where the predictions will be made.
      rng: The sampling random key.
      num_samples: The number of samples per trial.

    Returns:
      The samples in the specified trials. shape: (num_samples, num_trials)
    """
    if rng is None:
      rng = jax.random.PRNGKey(0)

    if not trials:
      return np.zeros((num_samples, 0))

    data = self._trials_to_data(self._all_completed_trials)
    self._rng, ard_rng = jax.random.split(self._rng, 2)
    model = self._build_gp_model_and_optimize_parameters(data, ard_rng)
    predictive = sp.UniformEnsemblePredictive(
        predictives=eqx.filter_jit(model.precompute_predictive)(data)
    )

    xs = self._converter.to_features(trials)
    xs = types.ModelInput(
        continuous=xs.continuous.replace_fill_value(0.0),
        categorical=xs.categorical.replace_fill_value(0),
    )
    samples = eqx.filter_jit(acquisitions.sample_from_predictive)(
        predictive, xs, num_samples, key=rng
    )  # (num_samples, num_trials)
    # Scope the samples to non-padded only (there's a single padded dimension).
    samples = samples[
        :, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0])
    ]
    # TODO: vectorize output warping.
    if self._output_warper is not None:
      return np.vstack([
          self._output_warper.unwarp(samples[i][..., np.newaxis]).reshape(-1)
          for i in range(samples.shape[0])
      ])
    else:
      raise TypeError(
          'Output warper is expected to be set, but found to be None.'
      )
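For reference, a minimal numpy sketch of the padding-mask scoping step in sample() above: padded trial columns are dropped from a (num_samples, num_trials) array. The mask names echo the diff, but these are toy arrays rather than the Vizier types.ModelInput API:

import numpy as np

num_samples, padded_trials = 4, 5
samples = np.arange(num_samples * padded_trials, dtype=float).reshape(
    num_samples, padded_trials
)
# True marks a padded (missing) trial in either feature group.
continuous_missing = np.array([False, False, True, False, True])
categorical_missing = np.array([False, False, False, False, True])
# Keep only the columns where neither feature group was padding.
scoped = samples[:, ~(continuous_missing | categorical_missing)]
assert scoped.shape == (num_samples, 3)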

  @profiler.record_runtime
  def predict(
      self,
      trials: Sequence[vz.TrialSuggestion],
      rng: Optional[jax.Array] = None,
      num_samples: Optional[int] = 1000,
  ) -> vza.Prediction:
    """Returns the mean and stddev for any given trials.

    The method samples from the warped GP model, unwarps the samples, and
    computes the empirical mean and standard deviation as an approximation.

    Arguments:
      trials: The trials where the predictions will be made.
      rng: The sampling random key used for approximation.
      num_samples: The number of samples used for the approximation.

    Returns:
      The predictions in the specified trials.
    """
    unwarped_samples = self.sample(trials, rng, num_samples)
    mean = np.mean(unwarped_samples, axis=0)
    stddev = np.std(unwarped_samples, axis=0)
    return vza.Prediction(mean=mean, stddev=stddev)
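predict() above is a Monte Carlo approximation: it draws unwarped samples via sample() and reduces over the sample axis. A self-contained numpy illustration of that reduction, with fabricated Gaussian samples standing in for the designer's output:

import numpy as np

rng = np.random.default_rng(0)
# Stand-in for designer.sample(trials); shape (num_samples, num_trials).
unwarped_samples = rng.normal(loc=2.0, scale=0.5, size=(1000, 3))
mean = np.mean(unwarped_samples, axis=0)  # Empirical per-trial mean.
stddev = np.std(unwarped_samples, axis=0)  # Empirical per-trial stddev.
assert mean.shape == (3,) and stddev.shape == (3,)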

@profiler.record_runtime(name_prefix='VizierGPUCBPEBandit', name='suggest')
def suggest(
self, count: Optional[int] = None
47 changes: 45 additions & 2 deletions vizier/_src/algorithms/designers/gp_ucb_pe_test.py
@@ -24,6 +24,7 @@
from vizier import pyvizier as vz
from vizier._src.algorithms.core import abstractions
from vizier._src.algorithms.designers import gp_ucb_pe
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.optimizers import eagle_strategy as es
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier.jax import optimizers
@@ -119,7 +120,10 @@ def test_on_flat_space(
config=gp_ucb_pe.UCBPEConfig(
ucb_coefficient=10.0,
explore_region_ucb_coefficient=0.5,
cb_violation_penalty_coefficient=10.0,
# Sets the penalty coefficient to 0.0 so that the PE acquisition
# value is exactly the standard deviation prediction based on all
# trials.
cb_violation_penalty_coefficient=0.0,
ucb_overwrite_probability=0.0,
pe_overwrite_probability=1.0 if pe_overwrite else 0.0,
# In high noise mode, the PE acquisition function is always used.
@@ -140,9 +144,17 @@
rng=jax.random.PRNGKey(1),
)

quasi_random_sampler = quasi_random.QuasiRandomDesigner(
problem.search_space,
)
test_trials = quasi_random_sampler.suggest(count=3)

all_active_trials = []
all_trials = []
trial_id = 1
last_prediction = None
last_samples = None
label_rng = jax.random.PRNGKey(1)
# Simulates batch suggestions with delayed feedback: the first two batches
# are generated by the designer without any completed trials (but all with
# active trials). Starting from the third batch, the oldest batch gets
@@ -165,9 +177,10 @@
for _ in range(batch_size):
measurement = vz.Measurement()
for mi in problem.metric_information:
label_rng, rng = jax.random.split(label_rng, 2)
measurement.metrics[mi.name] = float(
jax.random.uniform(
jax.random.PRNGKey(1),
rng,
minval=mi.min_value_or(lambda: -10.0),
maxval=mi.max_value_or(lambda: 10.0),
)
@@ -179,6 +192,36 @@
completed=abstractions.CompletedTrials(completed_trials),
all_active=abstractions.ActiveTrials(all_active_trials),
)
# After the designer is updated with completed trials, prediction and
# sampling results are expected to change.
if len(completed_trials) > 1:
# test the sample method.
samples = designer.sample(test_trials, num_samples=5)
self.assertSequenceEqual(samples.shape, (5, 3))
self.assertFalse(np.isnan(samples).any())
# test the sample method with a different rng.
samples_rng = designer.sample(
test_trials, num_samples=5, rng=jax.random.PRNGKey(1)
)
self.assertFalse(np.isnan(samples_rng).any())
self.assertFalse((np.abs(samples - samples_rng) <= 1e-6).all())
# test the predict method.
prediction = designer.predict(test_trials)
self.assertLen(prediction.mean, 3)
self.assertLen(prediction.stddev, 3)
self.assertFalse(np.isnan(prediction.mean).any())
self.assertFalse(np.isnan(prediction.stddev).any())
if last_prediction is None:
last_prediction = prediction
last_samples = samples
else:
self.assertFalse(
(np.abs(last_prediction.mean - prediction.mean) <= 1e-6).all()
)
self.assertFalse(
(np.abs(last_prediction.stddev - prediction.stddev) <= 1e-6).all()
)
self.assertFalse((np.abs(last_samples - samples) <= 1e-6).all())

self.assertLen(all_trials, (iters + 2) * batch_size)

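The new test assertions use an "arrays differ somewhere" idiom: assertFalse((np.abs(a - b) <= 1e-6).all()) fails only when every element of the two arrays agrees within tolerance. A tiny numpy sketch of the same check outside the test class:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.0, 3.5])
# Holds because at least one element differs by more than the tolerance.
assert not (np.abs(a - b) <= 1e-6).all()
# By contrast, comparing an array with itself would make the check fail.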
