diff --git a/ax/modelbridge/transforms/relativize.py b/ax/modelbridge/transforms/relativize.py index e9389afbf37..ec763a0d5a6 100644 --- a/ax/modelbridge/transforms/relativize.py +++ b/ax/modelbridge/transforms/relativize.py @@ -168,15 +168,10 @@ def _rel_op_on_observations( self.modelbridge.status_quo_data_by_trial, self.MISSING_STATUS_QUO_ERROR ) - missing_index = any(obs.features.trial_index is None for obs in observations) - default_trial_idx: Optional[int] = None - if missing_index: - if len(sq_data_by_trial) == 1: - default_trial_idx = next(iter(sq_data_by_trial)) - else: - raise ValueError( - "Observations contain missing trial index that can't be inferred." - ) + # use latest index of latest observed trial by default + # to handle pending trials, which may not have a trial_index + # if TrialAsTask was not used to generate the trial. + default_trial_idx: int = max(sq_data_by_trial.keys()) def _get_relative_data_from_obs( obs: Observation, diff --git a/ax/modelbridge/transforms/tests/test_relativize_transform.py b/ax/modelbridge/transforms/tests/test_relativize_transform.py index 1d65cf9cc20..21f3d888a5c 100644 --- a/ax/modelbridge/transforms/tests/test_relativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_relativize_transform.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import List, Tuple -from unittest.mock import Mock, patch, PropertyMock +from unittest.mock import Mock import numpy as np from ax.core import BatchTrial @@ -159,21 +159,6 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( observations[0].features.trial_index = 999 self.assertRaises(ValueError, tf.transform_observations, observations) - # When observation has missing trial_index and - # modelbridge.status_quo_data_by_trial has more than one trial, - # raise exception - observations[0].features.trial_index = None - with patch.object( - type(modelbridge), "status_quo_data_by_trial", new_callable=PropertyMock - ) as mock_sq_dict: - # Making modelbridge.status_quo_data_by_trial contains 2 trials - mock_sq_dict.return_value = {0: Mock(), 1: Mock()} - with self.assertRaisesRegex( - ValueError, - "Observations contain missing trial index that can't be inferred.", - ): - tf.transform_observations(observations) - def test_relativize_transform_observations(self) -> None: def _check_transform_observations( tf: Transform, @@ -257,8 +242,18 @@ def _check_transform_observations( observations=observations, expected_mean_and_covar=expected_mean_and_covar, ) - # transform should still work when trial_index is None and - # there is only one sq in modelbridge + # transform should still work when trial_index is None + modelbridge = Mock( + status_quo=Mock( + data=obs_data[0], features=obs_features[0], arm_name=arm_names[0] + ), + status_quo_data_by_trial={0: obs_data[1], 1: obs_data[0]}, + ) + tf = relativize_cls( + search_space=None, + observations=observations, + modelbridge=modelbridge, + ) for obs in observations: obs.features.trial_index = None _check_transform_observations( diff --git a/ax/modelbridge/transforms/tests/test_time_as_feature_transform.py b/ax/modelbridge/transforms/tests/test_time_as_feature_transform.py new file mode 100644 index 00000000000..d8b46020e71 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_time_as_feature_transform.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from copy import deepcopy +from unittest import mock + +import numpy as np +from ax.core.observation import Observation, ObservationData, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import UnsupportedError +from ax.modelbridge.transforms.time_as_feature import TimeAsFeature +from ax.utils.common.testutils import TestCase +from ax.utils.common.timeutils import unixtime_to_pandas_ts +from ax.utils.common.typeutils import checked_cast +from ax.utils.testing.core_stubs import get_robust_search_space + + +class TimeAsFeatureTransformTest(TestCase): + def setUp(self) -> None: + super().setUp() + self.search_space = SearchSpace( + parameters=[ + RangeParameter( + "x", lower=1, upper=4, parameter_type=ParameterType.FLOAT + ) + ] + ) + self.training_feats = [ + ObservationFeatures( + {"x": i + 1}, + trial_index=i, + start_time=unixtime_to_pandas_ts(float(i)), + end_time=unixtime_to_pandas_ts(float(i + 1 + i)), + ) + for i in range(4) + ] + self.training_obs = [ + Observation( + data=ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ), + features=obsf, + ) + for obsf in self.training_feats + ] + time_patcher = mock.patch( + "ax.modelbridge.transforms.time_as_feature.time", return_value=5.0 + ) + self.time_patcher = time_patcher.start() + self.addCleanup(time_patcher.stop) + self.t = TimeAsFeature( + search_space=self.search_space, + observations=self.training_obs, + ) + + def test_init(self) -> None: + self.assertEqual(self.t.current_time, 5.0) + self.assertEqual(self.t.min_duration, 1.0) + self.assertEqual(self.t.max_duration, 4.0) + self.assertEqual(self.t.duration_range, 3.0) + self.assertEqual(self.t.min_start_time, 0.0) + self.assertEqual(self.t.max_start_time, 3.0) + + # Test validation + obsf = ObservationFeatures({"x": 2}) + obs = Observation( + data=ObservationData([], np.array([]), np.empty((0, 0))), features=obsf + ) + msg = ( + "Unable to use TimeAsFeature since not all observations have " + "start time specified." + ) + with self.assertRaisesRegex(ValueError, msg): + TimeAsFeature( + search_space=self.search_space, + observations=self.training_obs + [obs], + ) + + t2 = TimeAsFeature( + search_space=self.search_space, + observations=self.training_obs[:1], + ) + self.assertEqual(t2.duration_range, 1.0) + + def test_TransformObservationFeatures(self) -> None: + obs_ft1 = deepcopy(self.training_feats) + obs_ft_trans1 = deepcopy(self.training_feats) + for i, obs in enumerate(obs_ft_trans1): + obs.parameters.update({"start_time": float(i), "duration": 1 / 3 * i}) + obs_ft1 = self.t.transform_observation_features(obs_ft1) + self.assertEqual(obs_ft1, obs_ft_trans1) + obs_ft1 = self.t.untransform_observation_features(obs_ft1) + self.assertEqual(obs_ft1, self.training_feats) + # test transforming observation features that do not have + # start_time/end_time + obsf = [ObservationFeatures({"x": 2.5})] + obsf_trans = self.t.transform_observation_features(obsf) + self.assertEqual( + obsf_trans[0], + ObservationFeatures({"x": 2.5, "duration": 0.5, "start_time": 5.0}), + ) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + self.assertEqual(set(ss2.parameters.keys()), {"x", "start_time", "duration"}) + p = checked_cast(RangeParameter, ss2.parameters["start_time"]) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 0.0) + self.assertEqual(p.upper, 3.0) + p = checked_cast(RangeParameter, ss2.parameters["duration"]) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 0.0) + self.assertEqual(p.upper, 1.0) + + def test_w_robust_search_space(self) -> None: + rss = get_robust_search_space() + # Raises an error in __init__. + with self.assertRaisesRegex(UnsupportedError, "transform is not supported"): + TimeAsFeature( + search_space=rss, + observations=[], + ) diff --git a/ax/modelbridge/transforms/time_as_feature.py b/ax/modelbridge/transforms/time_as_feature.py new file mode 100644 index 00000000000..a4804d3d790 --- /dev/null +++ b/ax/modelbridge/transforms/time_as_feature.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from logging import Logger +from time import time +from typing import List, Optional, TYPE_CHECKING + +import pandas as pd + +from ax.core.observation import Observation, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import RobustSearchSpace, SearchSpace +from ax.exceptions.core import UnsupportedError +from ax.modelbridge.transforms.base import Transform +from ax.models.types import TConfig +from ax.utils.common.logger import get_logger +from ax.utils.common.timeutils import unixtime_to_pandas_ts +from ax.utils.common.typeutils import checked_cast, not_none + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import modelbridge as modelbridge_module # noqa F401 + + +logger: Logger = get_logger(__name__) + + +class TimeAsFeature(Transform): + """Convert start time and duration into features that can be used for modeling. + + If no end_time is available, the current time is used. + + Duration is normalized to the unit cube. + + Transform is done in-place. + + TODO: revise this when better support for non-tunable features is added. + """ + + def __init__( + self, + search_space: Optional[SearchSpace] = None, + observations: Optional[List[Observation]] = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: Optional[TConfig] = None, + ) -> None: + assert observations is not None, "TimeAsFeature requires observations" + if isinstance(search_space, RobustSearchSpace): + raise UnsupportedError( + "TimeAsFeature transform is not supported for RobustSearchSpace." + ) + self.min_start_time: float = float("inf") + self.max_start_time: float = float("-inf") + self.min_duration: float = float("inf") + self.max_duration: float = float("-inf") + self.current_time: float = time() + for obs in observations: + obsf = obs.features + if obsf.start_time is None: + raise ValueError( + "Unable to use TimeAsFeature since not all observations have " + "start time specified." + ) + start_time = not_none(obsf.start_time).timestamp() + self.min_start_time = min(self.min_start_time, start_time) + self.max_start_time = max(self.max_start_time, start_time) + duration = self._get_duration(start_time=start_time, end_time=obsf.end_time) + self.min_duration = min(self.min_duration, duration) + self.max_duration = max(self.max_duration, duration) + self.duration_range: float = self.max_duration - self.min_duration + if self.duration_range == 0: + # no need to case-distinguish during normalization + self.duration_range = 1.0 + + def _get_duration( + self, start_time: float, end_time: Optional[pd.Timestamp] + ) -> float: + return ( + self.current_time if end_time is None else end_time.timestamp() + ) - start_time + + def transform_observation_features( + self, observation_features: List[ObservationFeatures] + ) -> List[ObservationFeatures]: + for obsf in observation_features: + if obsf.start_time is not None: + start_time = obsf.start_time.timestamp() + obsf.parameters["start_time"] = start_time + duration = self._get_duration( + start_time=start_time, end_time=obsf.end_time + ) + # normalize duration to the unit cube + obsf.parameters["duration"] = ( + duration - self.min_duration + ) / self.duration_range + else: + # start time can be None for pending arms that generated + # with a model that did not use the TimeAsFeature transform. + # In that case, assume the arm is going to be evaluated at the + # current time, and that the duration is the midpoint of the + # range. + obsf.parameters["start_time"] = self.current_time + obsf.parameters["duration"] = 0.5 + return observation_features + + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: + for p_name in ("start_time", "duration"): + if p_name in search_space.parameters: + raise ValueError( + f"Parameter name {p_name} is reserved when using " + "TimeAsFeature transform, but is part of the provided " + "search space. Please choose a different name for " + "this parameter." + ) + param = RangeParameter( + name="start_time", + parameter_type=ParameterType.FLOAT, + lower=self.min_start_time, + upper=self.max_start_time, + ) + search_space.add_parameter(param) + param = RangeParameter( + name="duration", + parameter_type=ParameterType.FLOAT, + # duration is normalized to [0,1] + lower=0.0, + upper=1.0, + ) + search_space.add_parameter(param) + return search_space + + def untransform_observation_features( + self, observation_features: List[ObservationFeatures] + ) -> List[ObservationFeatures]: + for obsf in observation_features: + start_time = checked_cast(float, obsf.parameters.pop("start_time")) + obsf.start_time = unixtime_to_pandas_ts(start_time) + obsf.end_time = unixtime_to_pandas_ts( + checked_cast(float, obsf.parameters.pop("duration")) + * self.duration_range + + self.min_duration + + start_time + ) + return observation_features diff --git a/ax/storage/transform_registry.py b/ax/storage/transform_registry.py index e81801ac7c0..9ea4b20796e 100644 --- a/ax/storage/transform_registry.py +++ b/ax/storage/transform_registry.py @@ -40,6 +40,7 @@ from ax.modelbridge.transforms.standardize_y import StandardizeY from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY from ax.modelbridge.transforms.task_encode import TaskEncode +from ax.modelbridge.transforms.time_as_feature import TimeAsFeature from ax.modelbridge.transforms.trial_as_task import TrialAsTask from ax.modelbridge.transforms.unit_x import UnitX from ax.modelbridge.transforms.winsorize import Winsorize @@ -90,6 +91,7 @@ Relativize: 24, RelativizeWithConstantControl: 25, MergeRepeatedMeasurements: 26, + TimeAsFeature: 27, } """ diff --git a/ax/utils/common/timeutils.py b/ax/utils/common/timeutils.py index 769ecef1555..62f33180ab6 100644 --- a/ax/utils/common/timeutils.py +++ b/ax/utils/common/timeutils.py @@ -52,3 +52,8 @@ def timestamps_in_range( while curr <= end: yield curr curr += delta + + +def unixtime_to_pandas_ts(ts: float) -> pd.Timestamp: + """Convert float unixtime into pandas timestamp (UTC).""" + return pd.to_datetime(ts, unit="s") diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 234d2310a4a..96f8367d9f7 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -369,6 +369,14 @@ Transforms :undoc-members: :show-inheritance: +`ax.modelbridge.transforms.time\_as\_feature` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.time_as_feature + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.trial\_as\_task` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~