Skip to content

Commit

Permalink
Add NormalizedSimpleRegert and WinRateSimpleRegret convergence curve …
Browse files Browse the repository at this point in the history
…comparators.

PiperOrigin-RevId: 599931886
  • Loading branch information
belenkil authored and copybara-github committed Jan 19, 2024
1 parent 5e3e074 commit fad0956
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 41 deletions.
114 changes: 98 additions & 16 deletions vizier/_src/benchmarks/analyzers/convergence_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,10 @@ def _standardize_curves(
Note: This is an helper function that the class implementing this interface
can choose to use or not.
1. Align xs and keeping each ys.
2. Apply quantiles and impute NaN.
3. Remove values where xs < xs_cutoff.
1. Align xs and keep each ys.
2. Convert curves to INCREASING.
3. Apply quantiles and impute NaN.
4. Remove values where xs < xs_cutoff.
Args:
xs_cutoff: The xs value before which values are ignored.
Expand Down Expand Up @@ -825,8 +826,8 @@ def score(self) -> float:


@attr.define
class SimpleConvergenceCurveComparator(ConvergenceComparator):
"""Comparator method based on simple comparison.
class StandardizedWinRateConvergenceCurveComparator(ConvergenceComparator):
"""Comparator method based on win rate on the standardized curves.
Attributes:
burn_cutoff: The cutoff below which values not included in score.
Expand All @@ -835,7 +836,7 @@ class SimpleConvergenceCurveComparator(ConvergenceComparator):
_xs_cutoff: Optional[float] = None

def score(self) -> float:
"""Computes the simple convergence score.
"""Computes the standardized win-rate convergence score.
The score is the percentage of indices (after the burn cutoff) for which the
interpolated values of 'compared_curve' are better than the interpolated
Expand All @@ -860,6 +861,13 @@ def curve(self) -> ConvergenceCurve:
class PercentageBetterConvergenceCurveComparator(ConvergenceComparator):
"""Comparator method based on percentage better.
PercentageBetter is the average percentage of steps that one curve is better
than the other.
For example, assuming a study with 100 trials, a score of 0.07 means that on
average for each 'baseline' trial the 'compared' convergence curve has already
reached that value 7 steps before.
Attributes:
burn_cutoff: The cutoff below which values not included in score.
"""
Expand All @@ -869,10 +877,7 @@ class PercentageBetterConvergenceCurveComparator(ConvergenceComparator):
def _compute_directional_score(
self, baseline: np.ndarray, compared: np.ndarray
) -> float:
"""Compute the percentage better score.
The score is the average percentage steps that 'compared' is better than
'baseline'.
"""Compute the percentage better score of 'compared' vs. 'baseline'.
Note that: sum_i sum_j 1{c_j > b_i} = sum_j sum_i {b_i < c_j}. Therefore, we
can either iterate over 'compared' and count the number of steps that
Expand All @@ -882,7 +887,7 @@ def _compute_directional_score(
Implementation
--------------
1. For each index of `baseline`:
- Finds the smallest index of 'comared' that is better.
- Finds the smallest index of 'compared' that is better.
- Compute the percentage of `compared` steps that are better.
2. Average the percentages across all 'baseline' indices.
Expand Down Expand Up @@ -958,6 +963,81 @@ def curve(self) -> ConvergenceCurve:
)


class WinRateSimpleRegretComparator(ConvergenceComparator):
"""Comparator method based on win-rate simple regert."""

def score(self):
"""Computes the normalized simple regert score."""
baseline_ys, compared_ys = self._standardize_curves()
print('compared_ys:', compared_ys)
print('baseline_ys:', baseline_ys)
return float(compared_ys[-1] > baseline_ys[-1])

def curve(self) -> ConvergenceCurve:
"""Returns a score curve for each xs."""
raise NotImplementedError('Curve not yet implemented.')


class NormalizedSimpleRegretComparator(ConvergenceComparator):
"""Comparator method based on normalized simple regert.
The simple regret gain ('compared' - 'baseline') is normalized by the
'baseline' absolute simple regret and then truncated.
"""

min_value: float = -0.5
max_value: float = 1.0
eps: float = 0.0001

def score(self):
"""Computes the normalized simple regert score."""
baseline_ys, compared_ys = self._standardize_curves()
d = (compared_ys[-1] - baseline_ys[-1]) / (abs(baseline_ys[-1]) + self.eps)
return min(max(d, self.min_value), self.max_value)

def curve(self) -> ConvergenceCurve:
"""Returns a score curve for each xs."""
raise NotImplementedError('Curve not yet implemented.')


class WinRateSimpleRegretComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for WinRateSimpleRegretComparator."""

def __call__(
self,
baseline_curve: ConvergenceCurve,
compared_curve: ConvergenceCurve,
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return WinRateSimpleRegretComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='win_rate_simple_regret',
)


class NormalizedSimpleRegretComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for NormalizedSimpleRegretComparator."""

def __call__(
self,
baseline_curve: ConvergenceCurve,
compared_curve: ConvergenceCurve,
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return NormalizedSimpleRegretComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='normalized_simple_regret',
)


class WinRateComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for WinRateComparatorFactory."""

Expand Down Expand Up @@ -1015,12 +1095,14 @@ def __call__(
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='%_better',
name='pct_better',
)


class SimpleConvergenceCurveComparatorFactory(ConvergenceComparatorFactory):
"""Factory class for SimpleConvergenceCurveCompartor."""
class StandardizedWinRateConvergenceCurveComparatorFactory(
ConvergenceComparatorFactory
):
"""Factory class for StandardizedWinRateConvergenceCurveComparator."""

def __call__(
self,
Expand All @@ -1029,12 +1111,12 @@ def __call__(
baseline_quantile: float = 0.5,
compared_quantile: float = 0.5,
) -> ConvergenceComparator:
return SimpleConvergenceCurveComparator(
return StandardizedWinRateConvergenceCurveComparator(
baseline_curve=baseline_curve,
compared_curve=compared_curve,
baseline_quantile=baseline_quantile,
compared_quantile=compared_quantile,
name='simple',
name='standardized_win_rate',
)


Expand Down
Loading

0 comments on commit fad0956

Please sign in to comment.