Skip to content

Commit

Permalink
Fix NormalizeLabels
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663395588
  • Loading branch information
xingyousong authored and copybara-github committed Aug 15, 2024
1 parent bdd4b1f commit cfbd1bf
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
31 changes: 24 additions & 7 deletions vizier/_src/algorithms/designers/gp/output_warpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import abc
import copy
from typing import Callable, Optional, Sequence
from typing import Callable, Optional, Sequence, Tuple

import attr
import attrs
Expand Down Expand Up @@ -519,11 +519,21 @@ def unwarp(self, labels_arr: types.Array) -> types.Array:
)


@attr.define
class NormalizeLabels(OutputWarper):
  """Normalizes the finite label values, leaving the NaNs & infinities out."""

  # (low, high) bounds that the finite labels are linearly mapped onto.
  target_interval: Tuple[float, float] = attr.ib(default=(0.0, 1.0))

  def __attrs_post_init__(self):
    """Rejects a target interval whose lower bound exceeds its upper bound."""
    if self.target_interval[0] > self.target_interval[1]:
      raise ValueError(f'Bounds {self.target_interval} is invalid.')

  def warp(self, labels_arr: types.Array) -> types.Array:
    """Normalizes the finite label values to bring them within target_interval.

    The smallest finite label is mapped to the lower bound of
    `target_interval`, the largest to the upper bound, and everything in
    between is interpolated linearly. If all finite labels are equal, they get
    mapped to target_interval's midpoint. Non-finite entries are left as they
    are.

    Args:
      labels_arr: (num_points, 1) shaped array of labels.

    Returns:
      The validated labels array with its finite entries replaced by their
      normalized values.

    Raises:
      ValueError: If every label is NaN. `_validate_labels` may raise as well;
        its checks are defined elsewhere in this module.
    """
    labels_arr = _validate_labels(labels_arr)
    if np.isnan(labels_arr).all():
      raise ValueError('Labels need to have at least one non-NaN entry.')

    labels_finite_ind = np.isfinite(labels_arr)
    labels_arr_finite = labels_arr[labels_finite_ind]

    source_interval = (np.min(labels_arr_finite), np.max(labels_arr_finite))

    # A degenerate source interval has no slope to interpolate along, so every
    # finite label is sent to the centre of the target interval instead.
    if source_interval[0] == source_interval[1]:
      midpoint = (self.target_interval[0] + self.target_interval[1]) / 2
      labels_arr[labels_finite_ind] = midpoint
      return labels_arr

    labels_arr_finite_normalized = np.interp(
        labels_arr_finite, source_interval, self.target_interval
    )
    # Only the finite positions are overwritten; NaN/inf entries pass through.
    labels_arr[labels_finite_ind] = labels_arr_finite_normalized
    return labels_arr

Expand Down
11 changes: 9 additions & 2 deletions vizier/_src/algorithms/designers/gp/output_warpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,20 @@ def test_known_arrays(self):

class NormalizeLabelsTest(_OutputWarperTestCase):
  """Checks `NormalizeLabels` against a small hand-computed example."""

  def setUp(self):
    super().setUp()
    # Column vector of evenly spaced labels, so the expected output is exact.
    self.labels_arr = np.asarray([10.0, 15.0, 20.0])[:, np.newaxis]

  @property
  def warper(self) -> OutputWarper:
    """Returns a warper built with the default target interval."""
    return output_warpers.NormalizeLabels()

  def test_known_arrays(self):
    """The min, middle and max labels map to 0.0, 0.5 and 1.0."""
    actual = self.warper.warp(self.labels_arr)
    expected = np.asarray([0.0, 0.5, 1.0])[:, np.newaxis]
    np.testing.assert_allclose(
        actual, expected, err_msg=f'actual: {actual.tolist()}'
    )


class DetectOutliersTest(_OutputWarperTestCase):
Expand Down

0 comments on commit cfbd1bf

Please sign in to comment.