diff --git a/ax/modelbridge/tests/test_derelativize_transform.py b/ax/modelbridge/tests/test_derelativize_transform.py index 270e559bd4a..3874e4a9cbe 100644 --- a/ax/modelbridge/tests/test_derelativize_transform.py +++ b/ax/modelbridge/tests/test_derelativize_transform.py @@ -6,7 +6,7 @@ from copy import deepcopy from unittest import mock -from unittest.mock import patch +from unittest.mock import Mock, patch import numpy as np from ax.core.data import Data @@ -75,20 +75,13 @@ def setUp(self) -> None: ] ), ) - # pyre-fixme[3]: Return type must be annotated. def test_DerelativizeTransform( self, - # pyre-fixme[2]: Parameter must be annotated. - mock_predict, - # pyre-fixme[2]: Parameter must be annotated. - mock_fit, - # pyre-fixme[2]: Parameter must be annotated. - mock_observations_from_data, - ): - t = Derelativize( - search_space=None, - observations=[], - ) + mock_predict: Mock, + mock_fit: Mock, + mock_observations_from_data: Mock, + ) -> None: + t = Derelativize(search_space=None, observations=[]) # ModelBridge with in-design status quo search_space = SearchSpace( @@ -167,6 +160,13 @@ def test_DerelativizeTransform( obsf = mock_predict.mock_calls[0][1][1][0] obsf2 = ObservationFeatures(parameters={"x": 2.0, "y": 10.0}) self.assertTrue(obsf == obsf2) + self.assertEqual(mock_predict.call_count, 1) + + # The model should not be used when `use_raw_status_quo` is True + t2 = deepcopy(t) + t2.config["use_raw_status_quo"] = True + t2.transform_optimization_config(deepcopy(oc), g, None) + self.assertEqual(mock_predict.call_count, 1) # Test with relative constraint, out-of-design status quo mock_predict.side_effect = RuntimeError() @@ -215,7 +215,7 @@ def test_DerelativizeTransform( ), ] ) - self.assertEqual(mock_predict.call_count, 2) + self.assertEqual(mock_predict.call_count, 1) # Raises error if predict fails with in-design status quo g = ModelBridge( diff --git a/ax/modelbridge/transforms/derelativize.py b/ax/modelbridge/transforms/derelativize.py index 57b36385ddb..64563733c0d 100644 --- a/ax/modelbridge/transforms/derelativize.py +++ b/ax/modelbridge/transforms/derelativize.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from logging import Logger from typing import List, Optional, TYPE_CHECKING import numpy as np @@ -14,6 +15,7 @@ from ax.modelbridge.base import unwrap_observation_data from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.ivw import ivw_metric_merge +from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none @@ -22,6 +24,9 @@ from ax import modelbridge as modelbridge_module # noqa F401 +logger: Logger = get_logger(__name__) + + class Derelativize(Transform): """Changes relative constraints to not-relative constraints using a plug-in estimate of the status quo value. @@ -59,22 +64,17 @@ def transform_optimization_config( "Optimization config has relative constraint, but model was " "not fit with status quo." ) - try: - f, _ = modelbridge.predict([modelbridge.status_quo.features]) - except Exception: - # Check if it is out-of-design. - if use_raw_sq or not modelbridge.model_space.check_membership( - modelbridge.status_quo.features.parameters - ): - # Out-of-design: use the raw observation - sq_data = ivw_metric_merge( - obsd=not_none(modelbridge.status_quo).data, - conflicting_noiseless="raise", - ) - f, _ = unwrap_observation_data([sq_data]) - else: - # Should have worked. - raise + + sq = not_none(modelbridge.status_quo) + # Only use model predictions if the status quo is in the search space (including + # parameter constraints) and `use_raw_sq` is false. + if not use_raw_sq and modelbridge.model_space.check_membership( + sq.features.parameters + ): + f, _ = modelbridge.predict([sq.features]) + else: + sq_data = ivw_metric_merge(obsd=sq.data, conflicting_noiseless="raise") + f, _ = unwrap_observation_data([sq_data]) # Plug in the status quo value to each relative constraint. for c in optimization_config.all_constraints: