Skip to content

Commit

Permalink
Allow removal of parameters with different values in RemoveFixed tran…
Browse files Browse the repository at this point in the history
…sform (#2779)

Summary:
Pull Request resolved: #2779

Previously, the `RemoveFixed.transform_observation_features` transform checked that the parameter being removed had the same value as the fixed parameter in the search space used to initialize the transform.
This diff removes this check to allow removal of any observation parameter with the same name. This will allow the transform to operate on observations from two similar but non-identical search spaces.

Reviewed By: susanxia1006

Differential Revision: D63322396

fbshipit-source-id: 2648a1cac171a871229101fcc1d44c4200f58d1d
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 24, 2024
1 parent 549adf0 commit 7a0ffb5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
10 changes: 2 additions & 8 deletions ax/modelbridge/transforms/remove_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,8 @@ def transform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
for p_name, fixed_p in self.fixed_parameters.items():
if p_name in obsf.parameters:
if obsf.parameters[p_name] != fixed_p.value:
raise ValueError(
f"Fixed parameter {p_name} with out of design value: "
f"{obsf.parameters[p_name]} passed to `RemoveFixed`."
)
obsf.parameters.pop(p_name)
for p_name in self.fixed_parameters:
obsf.parameters.pop(p_name, None)
return observation_features

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
Expand Down
11 changes: 7 additions & 4 deletions ax/modelbridge/transforms/tests/test_remove_fixed_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,15 @@ def test_TransformObservationFeatures(self) -> None:
observation_features = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
]
observation_features_invalid = [
observation_features_different = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b"})
]
# Fixed parameter out of design!
with self.assertRaises(ValueError):
self.t.transform_observation_features(observation_features_invalid)
# Fixed parameter is out of design. It will still get removed.
t_obs = self.t.transform_observation_features(observation_features)
t_obs_different = self.t.transform_observation_features(
observation_features_different
)
self.assertEqual(t_obs, t_obs_different)

def test_TransformSearchSpace(self) -> None:
ss2 = self.search_space.clone()
Expand Down

0 comments on commit 7a0ffb5

Please sign in to comment.