diff --git a/vizier/_src/algorithms/designers/gp/output_warpers.py b/vizier/_src/algorithms/designers/gp/output_warpers.py index 582a8fce3..144dcb2de 100644 --- a/vizier/_src/algorithms/designers/gp/output_warpers.py +++ b/vizier/_src/algorithms/designers/gp/output_warpers.py @@ -29,17 +29,26 @@ from tensorflow_probability.substrates import jax as tfp -def _validate_and_deepcopy(labels_arr: chex.Array) -> chex.Array: +def _validate_labels( + labels_arr: chex.Array, warping: bool = True +) -> chex.Array: """Checks and modifies the shape and values of the labels.""" labels_arr = labels_arr.astype(float) - labels_arr_copy = copy.deepcopy(labels_arr) if not (labels_arr.ndim == 2 and labels_arr.shape[-1] == 1): raise ValueError('Labels need to be an array of shape (num_points, 1).') if np.isposinf(labels_arr).any(): raise ValueError('Inifinity metric value is not valid.') if np.isneginf(labels_arr).any(): - labels_arr_copy[np.isneginf(labels_arr)] = np.nan - return labels_arr_copy + labels_arr[np.isneginf(labels_arr)] = np.nan + if ( + np.unique(labels_arr[np.isfinite(labels_arr).flatten(), :]).size <= 1 + and np.isnan(labels_arr).sum() == 0 + ) and warping: + raise ValueError( + 'Labels need to include at least two finite unique value in the absence' + ' of infeaible points.' + ) + return labels_arr class OutputWarper(abc.ABC): @@ -103,7 +112,9 @@ def warp(self, labels_arr: chex.Array) -> chex.Array: Returns: (num_points, 1) shaped array of warped labels. """ - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = copy.deepcopy(labels_arr) + if np.isneginf(labels_arr).any(): + labels_arr[np.isneginf(labels_arr)] = np.nan if np.isfinite(labels_arr).all() and len( np.unique(labels_arr).flatten()) == 1: return np.zeros(labels_arr.shape) @@ -127,7 +138,7 @@ def unwarp(self, labels_arr: chex.Array) -> chex.Array: Returns: (num_points, 1) shaped array of unwarped labels. """ - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = copy.deepcopy(labels_arr) if ( np.isfinite(labels_arr).all() and len(np.unique(labels_arr).flatten()) == 1 @@ -195,6 +206,11 @@ class HalfRankComponent(OutputWarper): untouched. """ + _median: Optional[float] = attr.field(default=None) + _stddev: Optional[float] = attr.field(default=None) + _dedup_median_index: Optional[int] = attr.field(default=None) + _unique_labels: Optional[chex.Array] = attr.field(default=None) + def _estimate_std_of_good_half( self, unique_labels: chex.Array, threshold: float ) -> float: @@ -221,16 +237,20 @@ def _estimate_std_of_good_half( def warp(self, labels_arr: chex.Array) -> chex.Array: """See base class.""" - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) if labels_arr.size == 1: return labels_arr labels_arr = labels_arr.flatten() # Compute median, unique labels, and ranks. median = np.nanmedian(labels_arr) + self._median = median + self._stddev = np.nanstd(labels_arr) unique_labels = np.unique(labels_arr[np.isfinite(labels_arr)]) + self._unique_labels = unique_labels ranks = stats.rankdata(labels_arr, method='dense') # nans ranked last. dedup_median_index = unique_labels.searchsorted(median, 'left') + self._dedup_median_index = dedup_median_index denominator = dedup_median_index + (unique_labels[dedup_median_index] == median) * .5 estimated_std = self._estimate_std_of_good_half(unique_labels, median) @@ -248,10 +268,43 @@ def warp(self, labels_arr: chex.Array) -> chex.Array: return np.reshape(labels_arr, [-1, 1]) def unwarp(self, labels_arr: chex.Array) -> chex.Array: - raise NotImplementedError( - 'unwarp method for HalfRankComponent is not implemented yet.' + labels_arr = _validate_labels(labels_arr, warping=False) + if np.isnan(labels_arr).any(): + raise ValueError('Array passed to unwarp cannot include nans.') + if self._dedup_median_index == 0: + return self._median + self._stddev * labels_arr + labels_arr[labels_arr >= 0.0] = ( + self._median + self._stddev * labels_arr[labels_arr >= 0.0] + ) + rank_bad = np.array( + [ + 2 * stats.norm.cdf(y) * (self._dedup_median_index + 0.5) - 0.5 + for y in labels_arr[labels_arr < 0.0] + ] + ) + if (rank_bad < -0.5).any() or ( + rank_bad > 1.0001 * self._dedup_median_index + ).any(): + raise ValueError('Rank needs to be within [-0.5, 1.0001 * median-index].') + labels_bad = np.ones(labels_arr[labels_arr < 0.0].shape) + scale = self._stddev + self._median - np.min(self._unique_labels) + if scale < 0.0: + raise ValueError('Scale needs to be non-negative.') + r_ints, r_fracs = divmod(rank_bad[rank_bad >= 0.0], 1) + labels_bad[rank_bad >= 0.0] = np.array( + [ + self._unique_labels(int(r_int)) * (1 - r_frac) + + (self._unique_labels(int(r_int) + 1) * r_frac) + for r_int, r_frac in zip(r_ints, r_fracs) + ] + ) + labels_bad[rank_bad < 0.0] = ( + np.min(self._unique_labels) + scale * rank_bad[rank_bad < 0.0] ) + labels_arr[labels_arr < 0.0] = labels_bad + return labels_arr + @attr.define class LogWarperComponent(OutputWarper): @@ -267,7 +320,7 @@ class LogWarperComponent(OutputWarper): def warp(self, labels_arr: chex.Array) -> chex.Array: """See base class.""" - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) self._labels_min = np.nanmin(labels_arr) self._labels_max = np.nanmax(labels_arr) labels_arr = labels_arr.flatten() @@ -303,7 +356,7 @@ class InfeasibleWarperComponent(OutputWarper): """Warps the infeasible/nan value to feasible/finite values.""" def warp(self, labels_arr: chex.Array) -> chex.Array: - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) labels_arr = labels_arr.flatten() labels_range = np.nanmax(labels_arr) - np.nanmin(labels_arr) warped_bad_value = np.nanmin(labels_arr) - (0.5 * labels_range + 1) @@ -328,7 +381,7 @@ def warp(self, labels_arr: chex.Array) -> chex.Array: Returns: (num_points, 1) shaped array of standardize labels. """ - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) if np.isnan(labels_arr).all(): raise ValueError('Labels need to have at least one non-NaN entry.') labels_finite_ind = np.isfinite(labels_arr) @@ -360,7 +413,8 @@ def warp(self, labels_arr: chex.Array) -> chex.Array: Returns: (num_points, 1) shaped array of normalized labels. """ - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) + if np.isnan(labels_arr).all(): raise ValueError('Labels need to have at least one non-NaN entry.') if np.nanmax(labels_arr) == np.nanmax(labels_arr): @@ -451,7 +505,7 @@ def _estimate_variance(self, labels_arr: chex.Array) -> float: (4 * num_points))**2) def warp(self, labels_arr: chex.Array) -> chex.Array: - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) labels_finite_ind = np.isfinite(labels_arr) labels_arr_finite = labels_arr[labels_finite_ind] labels_median = np.median(labels_arr_finite) @@ -496,7 +550,7 @@ def __init__( self.use_rank = use_rank def warp(self, labels_arr: chex.Array) -> chex.Array: - labels_arr = _validate_and_deepcopy(labels_arr) + labels_arr = _validate_labels(labels_arr) labels_arr = np.asarray(labels_arr, dtype=np.float64) labels_arr_flattened = labels_arr.flatten() if self.use_rank: diff --git a/vizier/_src/algorithms/designers/gp/output_warpers_test.py b/vizier/_src/algorithms/designers/gp/output_warpers_test.py index 8e1725c0a..faabdadf0 100644 --- a/vizier/_src/algorithms/designers/gp/output_warpers_test.py +++ b/vizier/_src/algorithms/designers/gp/output_warpers_test.py @@ -102,15 +102,6 @@ def warper(self) -> OutputWarper: def always_maps_to_finite(self) -> bool: return True - def test_all_nonfinite_labels(self): - labels_infeaible = np.array([[-np.inf], [np.nan], [np.nan], [-np.inf]]) - self.assertTrue( - ( - self.warper.warp(labels_infeaible) - == -1 * np.ones(shape=labels_infeaible.shape).flatten() - ).all() - ) - @parameterized.parameters([ dict(labels=np.zeros(shape=(5, 1))), dict(labels=np.ones(shape=(5, 1))), @@ -376,5 +367,20 @@ def test_known_arrays(self): # TODO: Add a couple of parameterized test cases. self.skipTest('No test cases provided') + +class OutputWarperPipelineTest(absltest.TestCase): + """Tests the default outpur warper edge cases.""" + + def test_all_nonfinite_labels(self): + warper = output_warpers.OutputWarperPipeline() + labels_infeaible = np.array([[-np.inf], [np.nan], [np.nan], [-np.inf]]) + self.assertTrue( + ( + warper.warp(labels_infeaible) + == -1 * np.ones(shape=labels_infeaible.shape).flatten() + ).all() + ) + + if __name__ == '__main__': absltest.main()