Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle ObservationFeatures without trial_index in Relativize #2441

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions ax/modelbridge/transforms/relativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,10 @@ def _rel_op_on_observations(
self.modelbridge.status_quo_data_by_trial, self.MISSING_STATUS_QUO_ERROR
)

missing_index = any(obs.features.trial_index is None for obs in observations)
default_trial_idx: Optional[int] = None
if missing_index:
if len(sq_data_by_trial) == 1:
default_trial_idx = next(iter(sq_data_by_trial))
else:
raise ValueError(
"Observations contain missing trial index that can't be inferred."
)
# Default to the largest trial index that has status quo data. This handles
# pending trials, which may not have a trial_index if TrialAsTask was not
# used to generate the trial.
default_trial_idx: int = max(sq_data_by_trial.keys())

def _get_relative_data_from_obs(
obs: Observation,
Expand Down
31 changes: 13 additions & 18 deletions ax/modelbridge/transforms/tests/test_relativize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from copy import deepcopy
from typing import List, Tuple
from unittest.mock import Mock, patch, PropertyMock
from unittest.mock import Mock

import numpy as np
from ax.core import BatchTrial
Expand Down Expand Up @@ -159,21 +159,6 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data(
observations[0].features.trial_index = 999
self.assertRaises(ValueError, tf.transform_observations, observations)

# When observation has missing trial_index and
# modelbridge.status_quo_data_by_trial has more than one trial,
# raise exception
observations[0].features.trial_index = None
with patch.object(
type(modelbridge), "status_quo_data_by_trial", new_callable=PropertyMock
) as mock_sq_dict:
# Making modelbridge.status_quo_data_by_trial contains 2 trials
mock_sq_dict.return_value = {0: Mock(), 1: Mock()}
with self.assertRaisesRegex(
ValueError,
"Observations contain missing trial index that can't be inferred.",
):
tf.transform_observations(observations)

def test_relativize_transform_observations(self) -> None:
def _check_transform_observations(
tf: Transform,
Expand Down Expand Up @@ -257,8 +242,18 @@ def _check_transform_observations(
observations=observations,
expected_mean_and_covar=expected_mean_and_covar,
)
# transform should still work when trial_index is None and
# there is only one sq in modelbridge
# transform should still work when trial_index is None
modelbridge = Mock(
status_quo=Mock(
data=obs_data[0], features=obs_features[0], arm_name=arm_names[0]
),
status_quo_data_by_trial={0: obs_data[1], 1: obs_data[0]},
)
tf = relativize_cls(
search_space=None,
observations=observations,
modelbridge=modelbridge,
)
for obs in observations:
obs.features.trial_index = None
_check_transform_observations(
Expand Down
129 changes: 129 additions & 0 deletions ax/modelbridge/transforms/tests/test_time_as_feature_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from copy import deepcopy
from unittest import mock

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.transforms.time_as_feature import TimeAsFeature
from ax.utils.common.testutils import TestCase
from ax.utils.common.timeutils import unixtime_to_pandas_ts
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import get_robust_search_space


class TimeAsFeatureTransformTest(TestCase):
    """Unit tests for the ``TimeAsFeature`` transform."""

    def setUp(self) -> None:
        super().setUp()
        # One-dimensional search space over the float parameter "x".
        self.search_space = SearchSpace(
            parameters=[
                RangeParameter(
                    "x", lower=1, upper=4, parameter_type=ParameterType.FLOAT
                )
            ]
        )
        # Four training features: trial i starts at unix time i and ends at
        # unix time 2 * i + 1, so the durations are 1, 2, 3 and 4 seconds.
        self.training_feats = [
            ObservationFeatures(
                {"x": i + 1},
                trial_index=i,
                start_time=unixtime_to_pandas_ts(float(i)),
                end_time=unixtime_to_pandas_ts(float(i + 1 + i)),
            )
            for i in range(4)
        ]
        # Observations carry no metric data; only the features are needed here.
        self.training_obs = [
            Observation(
                data=ObservationData(
                    metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
                ),
                features=obsf,
            )
            for obsf in self.training_feats
        ]
        # Freeze the clock seen by the transform module at 5.0. The patch must
        # be started before the transform is constructed below, because the
        # transform reads the current time in its constructor.
        time_patcher = mock.patch(
            "ax.modelbridge.transforms.time_as_feature.time", return_value=5.0
        )
        # NOTE: despite the attribute name, this stores the mock returned by
        # ``start()``, not the patcher object itself.
        self.time_patcher = time_patcher.start()
        self.addCleanup(time_patcher.stop)
        self.t = TimeAsFeature(
            search_space=self.search_space,
            observations=self.training_obs,
        )

    def test_init(self) -> None:
        # Statistics derived from the four training observations and the
        # patched clock (start times 0..3, durations 1..4).
        self.assertEqual(self.t.current_time, 5.0)
        self.assertEqual(self.t.min_duration, 1.0)
        self.assertEqual(self.t.max_duration, 4.0)
        self.assertEqual(self.t.duration_range, 3.0)
        self.assertEqual(self.t.min_start_time, 0.0)
        self.assertEqual(self.t.max_start_time, 3.0)

        # Test validation
        # An observation without a start time must be rejected at construction.
        obsf = ObservationFeatures({"x": 2})
        obs = Observation(
            data=ObservationData([], np.array([]), np.empty((0, 0))), features=obsf
        )
        msg = (
            "Unable to use TimeAsFeature since not all observations have "
            "start time specified."
        )
        with self.assertRaisesRegex(ValueError, msg):
            TimeAsFeature(
                search_space=self.search_space,
                observations=self.training_obs + [obs],
            )

        # With a single observation min and max duration coincide; the
        # transform replaces the resulting zero range with 1.0.
        t2 = TimeAsFeature(
            search_space=self.search_space,
            observations=self.training_obs[:1],
        )
        self.assertEqual(t2.duration_range, 1.0)

    def test_TransformObservationFeatures(self) -> None:
        obs_ft1 = deepcopy(self.training_feats)
        obs_ft_trans1 = deepcopy(self.training_feats)
        # Expected result: start_time i and normalized duration
        # ((i + 1) - min_duration) / duration_range == i / 3.
        for i, obs in enumerate(obs_ft_trans1):
            obs.parameters.update({"start_time": float(i), "duration": 1 / 3 * i})
        obs_ft1 = self.t.transform_observation_features(obs_ft1)
        self.assertEqual(obs_ft1, obs_ft_trans1)
        # Round trip: untransforming must recover the original features.
        obs_ft1 = self.t.untransform_observation_features(obs_ft1)
        self.assertEqual(obs_ft1, self.training_feats)
        # test transforming observation features that do not have
        # start_time/end_time
        # In that case the transform falls back to the (patched) current time
        # of 5.0 and the midpoint duration 0.5.
        obsf = [ObservationFeatures({"x": 2.5})]
        obsf_trans = self.t.transform_observation_features(obsf)
        self.assertEqual(
            obsf_trans[0],
            ObservationFeatures({"x": 2.5, "duration": 0.5, "start_time": 5.0}),
        )

    def test_TransformSearchSpace(self) -> None:
        ss2 = deepcopy(self.search_space)
        ss2 = self.t.transform_search_space(ss2)
        # The transform adds two float range parameters next to "x".
        self.assertEqual(set(ss2.parameters.keys()), {"x", "start_time", "duration"})
        # start_time is bounded by the observed start times (0 and 3).
        p = checked_cast(RangeParameter, ss2.parameters["start_time"])
        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
        self.assertEqual(p.lower, 0.0)
        self.assertEqual(p.upper, 3.0)
        # duration is normalized, so its bounds are always [0, 1].
        p = checked_cast(RangeParameter, ss2.parameters["duration"])
        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
        self.assertEqual(p.lower, 0.0)
        self.assertEqual(p.upper, 1.0)

    def test_w_robust_search_space(self) -> None:
        rss = get_robust_search_space()
        # Raises an error in __init__.
        with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
            TimeAsFeature(
                search_space=rss,
                observations=[],
            )
149 changes: 149 additions & 0 deletions ax/modelbridge/transforms/time_as_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from logging import Logger
from time import time
from typing import List, Optional, TYPE_CHECKING

import pandas as pd

from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.timeutils import unixtime_to_pandas_ts
from ax.utils.common.typeutils import checked_cast, not_none

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


logger: Logger = get_logger(__name__)


class TimeAsFeature(Transform):
    """Expose trial start time and duration as model features.

    Each observation's ``start_time`` (unix seconds) and duration
    (``end_time - start_time``) are written into its parameters as
    ``"start_time"`` and ``"duration"``. When ``end_time`` is missing, the
    time captured at construction is used in its place.

    Duration is normalized to the unit cube using the minimum duration and
    duration range seen in the observations passed to the constructor.

    Observation features and search spaces are modified in-place.

    TODO: revise this when better support for non-tunable features is added.
    """

    def __init__(
        self,
        search_space: Optional[SearchSpace] = None,
        observations: Optional[List[Observation]] = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        """Collect start-time and duration statistics from ``observations``.

        Args:
            search_space: Search space of the experiment; a
                ``RobustSearchSpace`` is rejected.
            observations: Observations whose features all carry a start time.
            modelbridge: Not used by this transform.
            config: Not used by this transform.

        Raises:
            UnsupportedError: If ``search_space`` is a ``RobustSearchSpace``.
            ValueError: If any observation has no ``start_time``.
        """
        assert observations is not None, "TimeAsFeature requires observations"
        if isinstance(search_space, RobustSearchSpace):
            raise UnsupportedError(
                "TimeAsFeature transform is not supported for RobustSearchSpace."
            )
        # Capture "now" once so every open-ended duration is measured against
        # the same instant.
        self.current_time: float = time()

        # Running extrema; the infinities are the identity elements of min/max.
        earliest_start, latest_start = float("inf"), float("-inf")
        shortest, longest = float("inf"), float("-inf")
        for observation in observations:
            features = observation.features
            start_ts = features.start_time
            if start_ts is None:
                raise ValueError(
                    "Unable to use TimeAsFeature since not all observations have "
                    "start time specified."
                )
            start = start_ts.timestamp()
            earliest_start = min(earliest_start, start)
            latest_start = max(latest_start, start)
            elapsed = self._get_duration(start_time=start, end_time=features.end_time)
            shortest = min(shortest, elapsed)
            longest = max(longest, elapsed)

        self.min_start_time: float = earliest_start
        self.max_start_time: float = latest_start
        self.min_duration: float = shortest
        self.max_duration: float = longest
        span = longest - shortest
        # A zero span would divide by zero during normalization; using 1.0
        # avoids having to special-case it there.
        self.duration_range: float = span if span != 0 else 1.0

    def _get_duration(
        self, start_time: float, end_time: Optional[pd.Timestamp]
    ) -> float:
        """Return seconds elapsed from ``start_time`` to ``end_time``.

        A missing ``end_time`` is replaced by ``self.current_time``.
        """
        end = self.current_time if end_time is None else end_time.timestamp()
        return end - start_time

    def transform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Add ``"start_time"`` and ``"duration"`` parameters in-place.

        Returns:
            The same list that was passed in, with its elements modified.
        """
        for obsf in observation_features:
            if obsf.start_time is None:
                # Start time can be None for pending arms that were generated
                # with a model that did not use the TimeAsFeature transform.
                # Assume such an arm will be evaluated at the current time
                # and that its duration is the midpoint of the range.
                obsf.parameters["start_time"] = self.current_time
                obsf.parameters["duration"] = 0.5
                continue
            start = obsf.start_time.timestamp()
            obsf.parameters["start_time"] = start
            elapsed = self._get_duration(start_time=start, end_time=obsf.end_time)
            # Normalize duration to the unit cube.
            normalized = (elapsed - self.min_duration) / self.duration_range
            obsf.parameters["duration"] = normalized
        return observation_features

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Add the ``start_time`` and ``duration`` range parameters in-place.

        Raises:
            ValueError: If the search space already has a parameter named
                ``"start_time"`` or ``"duration"``.
        """
        existing = search_space.parameters
        for p_name in ("start_time", "duration"):
            if p_name in existing:
                raise ValueError(
                    f"Parameter name {p_name} is reserved when using "
                    "TimeAsFeature transform, but is part of the provided "
                    "search space. Please choose a different name for "
                    "this parameter."
                )
        # start_time spans the range of start times seen at construction.
        search_space.add_parameter(
            RangeParameter(
                name="start_time",
                parameter_type=ParameterType.FLOAT,
                lower=self.min_start_time,
                upper=self.max_start_time,
            )
        )
        # duration is normalized to [0, 1].
        search_space.add_parameter(
            RangeParameter(
                name="duration",
                parameter_type=ParameterType.FLOAT,
                lower=0.0,
                upper=1.0,
            )
        )
        return search_space

    def untransform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Move ``"start_time"`` / ``"duration"`` back into timestamps in-place.

        Both parameters are removed from each feature's ``parameters`` and
        used to set its ``start_time`` and ``end_time``.

        Returns:
            The same list that was passed in, with its elements modified.
        """
        for obsf in observation_features:
            params = obsf.parameters
            start = checked_cast(float, params.pop("start_time"))
            obsf.start_time = unixtime_to_pandas_ts(start)
            normalized = checked_cast(float, params.pop("duration"))
            # Undo the unit-cube normalization, then offset by the start time.
            end = normalized * self.duration_range + self.min_duration + start
            obsf.end_time = unixtime_to_pandas_ts(end)
        return observation_features
2 changes: 2 additions & 0 deletions ax/storage/transform_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY
from ax.modelbridge.transforms.task_encode import TaskEncode
from ax.modelbridge.transforms.time_as_feature import TimeAsFeature
from ax.modelbridge.transforms.trial_as_task import TrialAsTask
from ax.modelbridge.transforms.unit_x import UnitX
from ax.modelbridge.transforms.winsorize import Winsorize
Expand Down Expand Up @@ -90,6 +91,7 @@
Relativize: 24,
RelativizeWithConstantControl: 25,
MergeRepeatedMeasurements: 26,
TimeAsFeature: 27,
}

"""
Expand Down
5 changes: 5 additions & 0 deletions ax/utils/common/timeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,8 @@ def timestamps_in_range(
while curr <= end:
yield curr
curr += delta


def unixtime_to_pandas_ts(ts: float) -> pd.Timestamp:
    """Convert a unix time given in seconds into a pandas timestamp.

    The seconds are counted from the unix epoch (UTC); the returned
    timestamp carries no timezone information.
    """
    timestamp = pd.to_datetime(ts, unit="s")
    return timestamp
8 changes: 8 additions & 0 deletions sphinx/source/modelbridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,14 @@ Transforms
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.time\_as\_feature`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.modelbridge.transforms.time_as_feature
:members:
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.trial\_as\_task`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down