Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure new arm names do not match a different name on the experiment #2732

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,10 @@ def _check_existing_and_name_arm(self, arm: Arm) -> None:
experiment, uses the existing arm name.
"""
proposed_name = self._get_default_name()

# Arm could already be in experiment, replacement is okay.
self.experiment._name_and_store_arm_if_not_exists(
arm=arm, proposed_name=proposed_name
arm=arm, proposed_name=proposed_name, replace=True
)
# If arm was named using given name, incremement the count
if arm.name == proposed_name:
Expand Down
4 changes: 3 additions & 1 deletion ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,9 @@ def set_status_quo_with_weight(
status_quo.parameters, raise_error=True
)
self.experiment._name_and_store_arm_if_not_exists(
arm=status_quo, proposed_name="status_quo_" + str(self.index)
arm=status_quo,
proposed_name="status_quo_" + str(self.index),
replace=True,
)
self._status_quo = status_quo.clone() if status_quo is not None else None
self._status_quo_weight_override = weight
Expand Down
24 changes: 22 additions & 2 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.core.trial import Trial
from ax.core.types import ComparisonOp, TParameterization
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
from ax.utils.common.base import Base
from ax.utils.common.constants import EXPERIMENT_IS_TEST_WARNING, Keys
from ax.utils.common.docutils import copy_doc
Expand Down Expand Up @@ -1350,7 +1350,9 @@ def warm_start_from_old_experiment(

return copied_trials

def _name_and_store_arm_if_not_exists(self, arm: Arm, proposed_name: str) -> None:
def _name_and_store_arm_if_not_exists(
self, arm: Arm, proposed_name: str, replace: bool = False
) -> None:
"""Tries to lookup arm with same signature, otherwise names and stores it.

- Looks up if arm already exists on experiment
Expand All @@ -1360,6 +1362,8 @@ def _name_and_store_arm_if_not_exists(self, arm: Arm, proposed_name: str) -> Non
Args:
arm: The arm object to name.
proposed_name: The name to assign if it doesn't have one already.
replace: If true, override arm w/ same name and different signature.
If false, raise an error if this conflict occurs.
"""

# If arm is identical to an existing arm, return that
Expand All @@ -1377,6 +1381,22 @@ def _name_and_store_arm_if_not_exists(self, arm: Arm, proposed_name: str) -> Non
else:
if not arm.has_name:
arm.name = proposed_name

# Check for signature conflict by arm name/proposed name
if (
arm.name in self.arms_by_name
and arm.signature != self.arms_by_name[arm.name].signature
):
error_msg = (
f"Arm with name {arm.name} already exists on experiment "
+ "with different signature."
)
if replace:
logger.warning(f"{error_msg} Replacing the existing arm. ")
else:
raise AxError(error_msg)

# Add the new arm
self._register_arm(arm)

def _register_arm(self, arm: Arm) -> None:
Expand Down
42 changes: 41 additions & 1 deletion ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from ax.core.search_space import SearchSpace
from ax.core.types import ComparisonOp
from ax.exceptions.core import UnsupportedError
from ax.exceptions.core import AxError, UnsupportedError
from ax.metrics.branin import BraninMetric
from ax.modelbridge.registry import Models
from ax.runners.synthetic import SyntheticRunner
Expand Down Expand Up @@ -1541,3 +1541,43 @@ class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose):
],
},
)

def test_name_and_store_arm_if_not_exists_same_name_different_signature(
self,
) -> None:
experiment = self.experiment
shared_name = "shared_name"

arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name)
arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1})
self.assertNotEqual(arm_1.signature, arm_2.signature)

experiment._register_arm(arm=arm_1)
with self.assertRaisesRegex(
AxError,
f"Arm with name {shared_name} already exists on experiment "
f"with different signature.",
):
experiment._name_and_store_arm_if_not_exists(
arm=arm_2, proposed_name=shared_name
)

def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature(
self,
) -> None:
experiment = self.experiment
shared_name = "shared_name"

arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name)
arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1}, name=shared_name)
self.assertNotEqual(arm_1.signature, arm_2.signature)

experiment._register_arm(arm=arm_1)
with self.assertRaisesRegex(
AxError,
f"Arm with name {shared_name} already exists on experiment "
f"with different signature.",
):
experiment._name_and_store_arm_if_not_exists(
arm=arm_2, proposed_name="different proposed name"
)