Skip to content

Commit

Permalink
Merge c4cead9 into 883a9f1
Browse files Browse the repository at this point in the history
  • Loading branch information
dme65 authored Dec 1, 2023
2 parents 883a9f1 + c4cead9 commit 7d51249
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
28 changes: 14 additions & 14 deletions ax/modelbridge/tests/test_derelativize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from copy import deepcopy
from unittest import mock
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy as np
from ax.core.data import Data
Expand Down Expand Up @@ -75,20 +75,13 @@ def setUp(self) -> None:
]
),
)
# pyre-fixme[3]: Return type must be annotated.
def test_DerelativizeTransform(
self,
# pyre-fixme[2]: Parameter must be annotated.
mock_predict,
# pyre-fixme[2]: Parameter must be annotated.
mock_fit,
# pyre-fixme[2]: Parameter must be annotated.
mock_observations_from_data,
):
t = Derelativize(
search_space=None,
observations=[],
)
mock_predict: Mock,
mock_fit: Mock,
mock_observations_from_data: Mock,
) -> None:
t = Derelativize(search_space=None, observations=[])

# ModelBridge with in-design status quo
search_space = SearchSpace(
Expand Down Expand Up @@ -167,6 +160,13 @@ def test_DerelativizeTransform(
obsf = mock_predict.mock_calls[0][1][1][0]
obsf2 = ObservationFeatures(parameters={"x": 2.0, "y": 10.0})
self.assertTrue(obsf == obsf2)
self.assertEqual(mock_predict.call_count, 1)

# The model should not be used when `use_raw_status_quo` is True
t2 = deepcopy(t)
t2.config["use_raw_status_quo"] = True
t2.transform_optimization_config(deepcopy(oc), g, None)
self.assertEqual(mock_predict.call_count, 1)

# Test with relative constraint, out-of-design status quo
mock_predict.side_effect = RuntimeError()
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_DerelativizeTransform(
),
]
)
self.assertEqual(mock_predict.call_count, 2)
self.assertEqual(mock_predict.call_count, 1)

# Raises error if predict fails with in-design status quo
g = ModelBridge(
Expand Down
32 changes: 16 additions & 16 deletions ax/modelbridge/transforms/derelativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from logging import Logger
from typing import List, Optional, TYPE_CHECKING

import numpy as np
Expand All @@ -14,6 +15,7 @@
from ax.modelbridge.base import unwrap_observation_data
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.ivw import ivw_metric_merge
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none


Expand All @@ -22,6 +24,9 @@
from ax import modelbridge as modelbridge_module # noqa F401


logger: Logger = get_logger(__name__)


class Derelativize(Transform):
"""Changes relative constraints to not-relative constraints using a plug-in
estimate of the status quo value.
Expand Down Expand Up @@ -59,22 +64,17 @@ def transform_optimization_config(
"Optimization config has relative constraint, but model was "
"not fit with status quo."
)
try:
f, _ = modelbridge.predict([modelbridge.status_quo.features])
except Exception:
# Check if it is out-of-design.
if use_raw_sq or not modelbridge.model_space.check_membership(
modelbridge.status_quo.features.parameters
):
# Out-of-design: use the raw observation
sq_data = ivw_metric_merge(
obsd=not_none(modelbridge.status_quo).data,
conflicting_noiseless="raise",
)
f, _ = unwrap_observation_data([sq_data])
else:
# Should have worked.
raise

sq = not_none(modelbridge.status_quo)
# Only use model predictions if the status quo is in the search space (including
# parameter constraints) and `use_raw_sq` is false.
if not use_raw_sq and modelbridge.model_space.check_membership(
sq.features.parameters
):
f, _ = modelbridge.predict([sq.features])
else:
sq_data = ivw_metric_merge(obsd=sq.data, conflicting_noiseless="raise")
f, _ = unwrap_observation_data([sq_data])

# Plug in the status quo value to each relative constraint.
for c in optimization_config.all_constraints:
Expand Down

0 comments on commit 7d51249

Please sign in to comment.