Skip to content

Commit

Permalink
Remove Scalarize-then-UCB code (we have it for PyPI upload anyways). …
Browse files Browse the repository at this point in the history
…Rename ScalarizedAcquisition. Prepare for 0.1.18 which fixes quasi-random and designer.predict()

PiperOrigin-RevId: 659969619
  • Loading branch information
xingyousong authored and copybara-github committed Aug 6, 2024
1 parent a8cfa4b commit 5980ed4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 34 deletions.
2 changes: 1 addition & 1 deletion vizier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@

sys.path.append(PROTO_ROOT)

__version__ = "0.1.17"
__version__ = "0.1.18"
3 changes: 1 addition & 2 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def __call__(
# then we should make it `eqx.Module` and set
# `reduction_fn=eqx.field(static=True)` instead.
@struct.dataclass
class ScalarizedAcquisition(AcquisitionFunction):
class ScalarizeOverAcquisitions(AcquisitionFunction):
"""Wrapper that scalarizes a vectorized acquisition function."""

acquisition_fn: AcquisitionFunction
Expand Down Expand Up @@ -589,7 +589,6 @@ def __call__(
return self.reduction_fn(scalarized)


# TODO: Temporary for experimentation.
@struct.dataclass
class AcquisitionOverScalarized(AcquisitionFunction):
"""Wrapper that applies acquisition over a scalarized distribution."""
Expand Down
8 changes: 4 additions & 4 deletions vizier/_src/algorithms/designers/gp/acquisitions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def test_scalarized_ucb(self):
weights=jnp.array([0.1, 0.2]), reference_point=reference_point
)

acq = acquisitions.ScalarizedAcquisition(ucb, scalarizer)
acq = acquisitions.ScalarizeOverAcquisitions(ucb, scalarizer)
self.assertAlmostEqual(
acq(tfd.Normal([0.1, 0.2], [0.1, 0.1])), jnp.array([1.0]), delta=1e-2
)

# Tests that the scalarized acquisition is larger with max_scalarized.
scalarized_labels = scalarizer(labels.unpad())
max_scalarized = jnp.max(scalarized_labels, axis=-1)
acq = acquisitions.ScalarizedAcquisition(
acq = acquisitions.ScalarizeOverAcquisitions(
ucb, scalarizer, max_scalarized=max_scalarized
)
self.assertAlmostEqual(
Expand All @@ -104,7 +104,7 @@ def test_ehvi_approximation(self):

# Tests that the scalarizer gives the approximate hypervolume with mean
# and uses constant rescaling of pi/4 for num_objs=2.
hypervolume = acquisitions.ScalarizedAcquisition(
hypervolume = acquisitions.ScalarizeOverAcquisitions(
acquisitions.UCB(coefficient=0.0),
scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_ehvi_mcmc(self):

# Tests that the scalarizer gives the approximate hypervolume with mean
# and uses constant rescaling of pi/4 for num_objs=2.
hypervolume = acquisitions.ScalarizedAcquisition(
hypervolume = acquisitions.ScalarizeOverAcquisitions(
acquisitions.Sample(num_samples=100),
scalarizer,
reduction_fn=lambda x: jnp.mean(jax.nn.relu(x)),
Expand Down
28 changes: 7 additions & 21 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import copy
import dataclasses
import datetime
import functools
import random
from typing import Optional, Sequence

Expand Down Expand Up @@ -146,12 +145,8 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
)

# Multi-objective parameters.
_num_samples: Optional[int] = attr.field(default=None, kw_only=True)
_num_scalarizations: int = attr.field(default=1000, kw_only=True)
_ref_scaling: float = attr.field(default=0.01, kw_only=True)
# Should be true generally, keeps track of maximum scalarized value in each
# direction for cumulative comparisons.
_use_max_scalarized: bool = attr.field(default=True, kw_only=True)

# ------------------------------------------------------------------
# Internal attributes which should not be set by callers.
Expand Down Expand Up @@ -217,20 +212,6 @@ def __attrs_post_init__(self):
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

acquisition_fn = acq_lib.UCB()
if self._num_samples is None:
acq_factory = functools.partial(
acq_lib.ScalarizedAcquisition,
acquisition_fn=acquisition_fn,
reduction_fn=lambda x: jnp.mean(x, axis=0),
)
else:
acq_factory = functools.partial(
acq_lib.AcquisitionOverScalarized,
acquisition_fn=acquisition_fn,
num_samples=self._num_samples,
)

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
Expand All @@ -243,10 +224,15 @@ def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)

max_scalarized = None
if has_labels and self._use_max_scalarized:
if has_labels:
max_scalarized = jnp.max(scalarizer(labels_array), axis=-1)

return acq_factory(scalarizer=scalarizer, max_scalarized=max_scalarized)
return acq_lib.ScalarizeOverAcquisitions(
acquisition_fn=acq_lib.UCB(),
scalarizer=scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
max_scalarized=max_scalarized,
)

self._scoring_function_factory = (
acq_lib.bayesian_scoring_function_factory(acq_fn_factory)
Expand Down
8 changes: 2 additions & 6 deletions vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,7 @@ def _qei_factory(data: types.ModelData) -> acquisitions.AcquisitionFunction:
iters * n_parallel,
)

@parameterized.parameters(
dict(num_samples=10),
dict(num_samples=None),
)
def test_multi_metrics(self, num_samples: int | None):
def test_multi_metrics(self):
search_space = vz.SearchSpace()
search_space.root.add_float_param('x0', -5.0, 5.0)
problem = vz.ProblemStatement(
Expand All @@ -493,7 +489,7 @@ def test_multi_metrics(self, num_samples: int | None):
)

iters = 2
designer = gp_bandit.VizierGPBandit(problem, num_samples=num_samples)
designer = gp_bandit.VizierGPBandit(problem)
self.assertLen(
test_runners.RandomMetricsRunner(
problem,
Expand Down

0 comments on commit 5980ed4

Please sign in to comment.