Skip to content

Commit

Permalink
Clarify and consolidate convergence curve scores
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619602169
  • Loading branch information
belenkil authored and copybara-github committed Mar 27, 2024
1 parent 0427719 commit 6da1243
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 158 deletions.
197 changes: 82 additions & 115 deletions vizier/_src/benchmarks/analyzers/convergence_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import copy
import enum
import logging
from typing import Callable, List, Optional, Protocol, Sequence, Union
from typing import Callable, List, Literal, Optional, Protocol, Sequence, Union

import attr
import numpy as np
Expand Down Expand Up @@ -633,43 +633,41 @@ def curve(self) -> ConvergenceCurve:
def name(self) -> str:
return self._name

def _standardize_curves(
def standardize_curves(
self,
xs_cutoff: Optional[float] = None,
apply_quantiles: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
"""Standardize convergence curves.
Note: This is an helper function that the class implementing this interface
can choose to use or not.
1. Align xs and keep each ys.
2. Convert curves to INCREASING.
3. Apply quantiles and impute NaN.
3. Apply quantiles and impute NaN (optional).
4. Remove values where xs < xs_cutoff.
Args:
xs_cutoff: The xs value before which values are ignored.
apply_quantiles: Whether to compute quantiles on the batches.
Returns:
The standardize baseline and compared curves.
The standardize baseline and compared curves. If apply_quantiles=True, the
shape is (num_steps,), otherwise the shape is (batch_size, num_steps).
"""
# Align the curves while keeping each ys.
align_baseline_curve, align_compared_curve = ConvergenceCurve.align_xs(
[self._baseline_curve, self._compared_curve],
interpolate_repeats=False,
keep_curves_separate=True,
)
# Apply batch quantiels.
baseline_ys = np.nanquantile(
self._sign * align_baseline_curve.ys,
self._baseline_quantile,
axis=0,
)
compared_ys = np.nanquantile(
self._sign * align_compared_curve.ys,
self._compared_quantile,
axis=0,
)
# Adjust sign to increasing.
baseline_ys = self._sign * align_baseline_curve.ys
compared_ys = self._sign * align_compared_curve.ys

if apply_quantiles:
# Apply batch quantiels (notice the dimension reduction).
baseline_ys = np.nanquantile(baseline_ys, self._baseline_quantile, axis=0)
compared_ys = np.nanquantile(compared_ys, self._compared_quantile, axis=0)

# Impute NaN values as -inf. This happens due to `align_xs` assigning
# np.nan for xs that are outside the original convergence curve.
baseline_ys = np.nan_to_num(baseline_ys, nan=-np.inf)
Expand Down Expand Up @@ -825,38 +823,6 @@ def score(self) -> float:
return self.summary_function(diff)


@attr.define
class StandardizedWinRateConvergenceCurveComparator(ConvergenceComparator):
"""Comparator method based on win rate on the standardized curves.
Attributes:
burn_cutoff: The cutoff below which values not included in score.
"""

_xs_cutoff: Optional[float] = None

def score(self) -> float:
"""Computes the standardized win-rate convergence score.
The score is the percentage of indices (after the burn cutoff) for which the
interpolated values of 'compared_curve' are better than the interpolated
values of 'baseline_curve'. A large score value means 'compared' is better.
Returns:
The percentage better convergence score [0.0, 1.0].
Raises:
ValueError: If curve trends are not INCREASING or DECREASING, or not
equal.
"""
baseline_ys, compared_ys = self._standardize_curves(self._xs_cutoff)
# Compute mean indices that compared is better than baseline.
return np.mean(baseline_ys < compared_ys)

def curve(self) -> ConvergenceCurve:
raise NotImplementedError('Curve not yet implemented.')


@attr.define
class PercentageBetterConvergenceCurveComparator(ConvergenceComparator):
"""Comparator method based on percentage better.
Expand All @@ -869,7 +835,7 @@ class PercentageBetterConvergenceCurveComparator(ConvergenceComparator):
reached that value 7 steps before.
Attributes:
burn_cutoff: The cutoff below which values not included in score.
xs_cutoff: The cutoff below which values not included in score.
"""

_xs_cutoff: Optional[float] = None
Expand Down Expand Up @@ -925,7 +891,7 @@ def score(self) -> float:
ValueError: If curve trends are not INCREASING or DECREASING, or not
equal.
"""
baseline_ys, compared_ys = self._standardize_curves(self._xs_cutoff)
baseline_ys, compared_ys = self.standardize_curves(self._xs_cutoff)
baseline_compared_score = self._compute_directional_score(
baseline_ys, compared_ys
)
Expand All @@ -938,54 +904,72 @@ def curve(self) -> ConvergenceCurve:
raise NotImplementedError('Curve not yet implemented.')


class WinRateComparator(ConvergenceComparator):
"""Comparator method based on simple win rate comparison."""
@attr.define
class WinRateConvergenceCurveComparator(ConvergenceComparator):
"""Comparator method based on convergence curves simple win rate comparison.
The comparator has two modes of comparing convergence curves:
1. Pairwise - Compare all pairs of repeated convergence curves and then
compute the mean win-rate over all the steps (i.e. trial) and pairs.
2. Quantiles - First compute the quantiles convergence curve per step across
the repeates, and then compute the mean win-rate over the steps.
The score ranges within [-0.5, 0.5], such that a score of 0.5 indicates that
the 'compared' curve is better than 'baseline' across all stpes.
"""

comparison_mode: Literal['pairwise', 'quantiles'] = 'pairwise'

def score(self) -> float:
return np.nanmean(self.curve().ys)

def curve(self) -> ConvergenceCurve:
"""Computes the curve that represents the average win rate."""
baseline_ys = (
self._sign * self._baseline_curve.ys
) # [N x T] array where N is the number of curves.
compared_ys = self._sign * self._compared_curve.ys

# Compares all pairs of compared to baseline curve.
all_comparisons = np.apply_along_axis(
lambda base: np.mean(compared_ys > base, axis=0)
+ 0.5 * np.mean(base == compared_ys, axis=0),
axis=1,
arr=baseline_ys,
)
# Note that 0.5 is the natural average, so subtracting it to make
# positive/negative score imply better/worse comparison.
return ConvergenceCurve(
xs=self._baseline_curve.xs,
ys=np.mean(all_comparisons, axis=0, keepdims=True) - 0.5,
)
if self.comparison_mode == 'pairwise':
baseline_ys, compared_ys = self.standardize_curves(apply_quantiles=False)
# Compares all pairs of compared to baseline curve.
all_comparisons = np.apply_along_axis(
lambda base: np.mean(compared_ys > base, axis=0)
+ 0.5 * np.mean(base == compared_ys, axis=0),
axis=1,
arr=baseline_ys,
)
curve_ys = np.mean(all_comparisons, axis=0, keepdims=True) - 0.5
elif self.comparison_mode == 'quantiles':
baseline_ys, compared_ys = self.standardize_curves(apply_quantiles=True)
curve_ys = np.asarray(
compared_ys > baseline_ys, dtype='float'
) + 0.5 * np.asarray(compared_ys == baseline_ys, dtype='float')
# Note that 0.5 is the natural average, so subtracting it to make
# positive/negative score imply better/worse comparison.
curve_ys = curve_ys[np.newaxis, :] - 0.5
else:
raise ValueError(f'Unknown comparison mode: {self.comparison_mode}')

return ConvergenceCurve(xs=self._baseline_curve.xs, ys=curve_ys)

class WinRateSimpleRegretComparator(ConvergenceComparator):
"""Comparator method based on win-rate simple regert."""

@attr.define
class OptimalityGapWinRateComparator(ConvergenceComparator):
"""Comparator method based on win-rate of the optimality gap."""

def score(self):
"""Computes the normalized simple regert score."""
baseline_ys, compared_ys = self._standardize_curves()
print('compared_ys:', compared_ys)
print('baseline_ys:', baseline_ys)
baseline_ys, compared_ys = self.standardize_curves()
return float(compared_ys[-1] > baseline_ys[-1])

def curve(self) -> ConvergenceCurve:
"""Returns a score curve for each xs."""
raise NotImplementedError('Curve not yet implemented.')


class NormalizedSimpleRegretComparator(ConvergenceComparator):
"""Comparator method based on normalized simple regert.
class OptimalityGapGainComparator(ConvergenceComparator):
"""Comparator method based on optimality gap gain.
The simple regret gain ('compared' - 'baseline') is normalized by the
'baseline' absolute simple regret and then truncated.
The optimality gap gain ('compared' - 'baseline') is normalized by the
'baseline' absolute optimality gap and then truncated.
"""

min_value: float = -0.5
Expand All @@ -994,7 +978,7 @@ class NormalizedSimpleRegretComparator(ConvergenceComparator):

def score(self):
"""Computes the normalized simple regert score."""
baseline_ys, compared_ys = self._standardize_curves()
baseline_ys, compared_ys = self.standardize_curves()
d = (compared_ys[-1] - baseline_ys[-1]) / (abs(baseline_ys[-1]) + self.eps)
return min(max(d, self.min_value), self.max_value)

Expand All @@ -1003,8 +987,8 @@ def curve(self) -> ConvergenceCurve:
raise NotImplementedError('Curve not yet implemented.')


class WinRateSimpleRegretComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for WinRateSimpleRegretComparator."""
class OptimalityGapWinRateComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for OptimalityGapWinRateComparator."""

def __call__(
self,
Expand All @@ -1013,17 +997,17 @@ def __call__(
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return WinRateSimpleRegretComparator(
return OptimalityGapWinRateComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='win_rate_simple_regret',
name='optimality_gap_win_rate',
)


class NormalizedSimpleRegretComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for NormalizedSimpleRegretComparator."""
class OptimalityGapGainComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for OptimalityGapGainComparator."""

def __call__(
self,
Expand All @@ -1032,17 +1016,20 @@ def __call__(
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return NormalizedSimpleRegretComparator(
return OptimalityGapGainComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='normalized_simple_regret',
name='optimality_gap_gain',
)


class WinRateComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for WinRateComparatorFactory."""
@attr.define
class WinRateConvergenceCurveComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for WinRateConvergenceCurveComparatorFactory."""

comparison_mode: Literal['pairwise', 'quantiles'] = 'pairwise'

def __call__(
self,
Expand All @@ -1051,12 +1038,13 @@ def __call__(
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return WinRateComparator(
return WinRateConvergenceCurveComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='win_rate',
name='convergence_curve_win_rate',
comparison_mode=self.comparison_mode,
)


Expand Down Expand Up @@ -1102,27 +1090,6 @@ def __call__(
)


class StandardizedWinRateConvergenceCurveComparatorFactory(
ConvergenceComparatorFactory
):
"""Factory class for StandardizedWinRateConvergenceCurveComparator."""

def __call__(
self,
baseline_curve: ConvergenceCurve,
compared_curve: ConvergenceCurve,
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return StandardizedWinRateConvergenceCurveComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='standardized_win_rate',
)


def build_convergence_curve(
baseline_curve: Sequence[float], compared_curve: Sequence[float]
) -> List[float]:
Expand Down
Loading

0 comments on commit 6da1243

Please sign in to comment.