From 5980ed494ff702b89011470237698956c0619824 Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Tue, 6 Aug 2024 08:37:12 -0700 Subject: [PATCH] Remove Scalarize-then-UCB code (we have it for PyPI upload anyways). Rename ScalarizedAcquisition. Prepare for 0.1.18 which fixes quasi-random and designer.predict() PiperOrigin-RevId: 659969619 --- vizier/__init__.py | 2 +- .../algorithms/designers/gp/acquisitions.py | 3 +- .../designers/gp/acquisitions_test.py | 8 +++--- vizier/_src/algorithms/designers/gp_bandit.py | 28 +++++-------------- .../algorithms/designers/gp_bandit_test.py | 8 ++---- 5 files changed, 15 insertions(+), 34 deletions(-) diff --git a/vizier/__init__.py b/vizier/__init__.py index 15bd3bd6e..2aa1bbb09 100644 --- a/vizier/__init__.py +++ b/vizier/__init__.py @@ -23,4 +23,4 @@ sys.path.append(PROTO_ROOT) -__version__ = "0.1.17" +__version__ = "0.1.18" diff --git a/vizier/_src/algorithms/designers/gp/acquisitions.py b/vizier/_src/algorithms/designers/gp/acquisitions.py index 4ee7ec71d..4e9b37951 100644 --- a/vizier/_src/algorithms/designers/gp/acquisitions.py +++ b/vizier/_src/algorithms/designers/gp/acquisitions.py @@ -561,7 +561,7 @@ def __call__( # then we should make it `eqx.Module` and set # `reduction_fn=eqx.field(static=True)` instead. @struct.dataclass -class ScalarizedAcquisition(AcquisitionFunction): +class ScalarizeOverAcquisitions(AcquisitionFunction): """Wrapper that scalarizes a vectorized acquisition function.""" acquisition_fn: AcquisitionFunction @@ -589,7 +589,6 @@ def __call__( return self.reduction_fn(scalarized) -# TODO: Temporary for experimentation. @struct.dataclass class AcquisitionOverScalarized(AcquisitionFunction): """Wrapper that applies acquisition over a scalarized distribution.""" diff --git a/vizier/_src/algorithms/designers/gp/acquisitions_test.py b/vizier/_src/algorithms/designers/gp/acquisitions_test.py index ad9e3b693..442ec370c 100644 --- a/vizier/_src/algorithms/designers/gp/acquisitions_test.py +++ b/vizier/_src/algorithms/designers/gp/acquisitions_test.py @@ -77,7 +77,7 @@ def test_scalarized_ucb(self): weights=jnp.array([0.1, 0.2]), reference_point=reference_point ) - acq = acquisitions.ScalarizedAcquisition(ucb, scalarizer) + acq = acquisitions.ScalarizeOverAcquisitions(ucb, scalarizer) self.assertAlmostEqual( acq(tfd.Normal([0.1, 0.2], [0.1, 0.1])), jnp.array([1.0]), delta=1e-2 ) @@ -85,7 +85,7 @@ def test_scalarized_ucb(self): # Tests that the scalarized acquisition is larger with max_scalarized. scalarized_labels = scalarizer(labels.unpad()) max_scalarized = jnp.max(scalarized_labels, axis=-1) - acq = acquisitions.ScalarizedAcquisition( + acq = acquisitions.ScalarizeOverAcquisitions( ucb, scalarizer, max_scalarized=max_scalarized ) self.assertAlmostEqual( @@ -104,7 +104,7 @@ def test_ehvi_approximation(self): # Tests that the scalarizer gives the approximate hypervolume with mean # and uses constant rescaling of pi/4 for num_objs=2. - hypervolume = acquisitions.ScalarizedAcquisition( + hypervolume = acquisitions.ScalarizeOverAcquisitions( acquisitions.UCB(coefficient=0.0), scalarizer, reduction_fn=lambda x: jnp.mean(x, axis=0), @@ -154,7 +154,7 @@ def test_ehvi_mcmc(self): # Tests that the scalarizer gives the approximate hypervolume with mean # and uses constant rescaling of pi/4 for num_objs=2. - hypervolume = acquisitions.ScalarizedAcquisition( + hypervolume = acquisitions.ScalarizeOverAcquisitions( acquisitions.Sample(num_samples=100), scalarizer, reduction_fn=lambda x: jnp.mean(jax.nn.relu(x)), diff --git a/vizier/_src/algorithms/designers/gp_bandit.py b/vizier/_src/algorithms/designers/gp_bandit.py index be1f426ef..698655251 100644 --- a/vizier/_src/algorithms/designers/gp_bandit.py +++ b/vizier/_src/algorithms/designers/gp_bandit.py @@ -24,7 +24,6 @@ import copy import dataclasses import datetime -import functools import random from typing import Optional, Sequence @@ -146,12 +145,8 @@ class VizierGPBandit(vza.Designer, vza.Predictor): ) # Multi-objective parameters. - _num_samples: Optional[int] = attr.field(default=None, kw_only=True) _num_scalarizations: int = attr.field(default=1000, kw_only=True) _ref_scaling: float = attr.field(default=0.01, kw_only=True) - # Should be true generally, keeps track of maximum scalarized value in each - # direction for cumulative comparisons. - _use_max_scalarized: bool = attr.field(default=True, kw_only=True) # ------------------------------------------------------------------ # Internal attributes which should not be set by callers. @@ -217,20 +212,6 @@ def __attrs_post_init__(self): weights = jnp.abs(weights) weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True) - acquisition_fn = acq_lib.UCB() - if self._num_samples is None: - acq_factory = functools.partial( - acq_lib.ScalarizedAcquisition, - acquisition_fn=acquisition_fn, - reduction_fn=lambda x: jnp.mean(x, axis=0), - ) - else: - acq_factory = functools.partial( - acq_lib.AcquisitionOverScalarized, - acquisition_fn=acquisition_fn, - num_samples=self._num_samples, - ) - def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction: # Scalarized UCB. labels_array = data.labels.padded_array @@ -243,10 +224,15 @@ def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction: scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point) max_scalarized = None - if has_labels and self._use_max_scalarized: + if has_labels: max_scalarized = jnp.max(scalarizer(labels_array), axis=-1) - return acq_factory(scalarizer=scalarizer, max_scalarized=max_scalarized) + return acq_lib.ScalarizeOverAcquisitions( + acquisition_fn=acq_lib.UCB(), + scalarizer=scalarizer, + reduction_fn=lambda x: jnp.mean(x, axis=0), + max_scalarized=max_scalarized, + ) self._scoring_function_factory = ( acq_lib.bayesian_scoring_function_factory(acq_fn_factory) diff --git a/vizier/_src/algorithms/designers/gp_bandit_test.py b/vizier/_src/algorithms/designers/gp_bandit_test.py index 41fe101e8..582c6399a 100644 --- a/vizier/_src/algorithms/designers/gp_bandit_test.py +++ b/vizier/_src/algorithms/designers/gp_bandit_test.py @@ -471,11 +471,7 @@ def _qei_factory(data: types.ModelData) -> acquisitions.AcquisitionFunction: iters * n_parallel, ) - @parameterized.parameters( - dict(num_samples=10), - dict(num_samples=None), - ) - def test_multi_metrics(self, num_samples: int | None): + def test_multi_metrics(self): search_space = vz.SearchSpace() search_space.root.add_float_param('x0', -5.0, 5.0) problem = vz.ProblemStatement( @@ -493,7 +489,7 @@ def test_multi_metrics(self, num_samples: int | None): ) iters = 2 - designer = gp_bandit.VizierGPBandit(problem, num_samples=num_samples) + designer = gp_bandit.VizierGPBandit(problem) self.assertLen( test_runners.RandomMetricsRunner( problem,