Skip to content

Commit

Permalink
use all constraints (outcome constraints + objective thresholds) to d…
Browse files Browse the repository at this point in the history
…etermine feasibility (#1313)

Summary:
Pull Request resolved: #1313

Also copies index when returning "all true" when no constraints present.
Zachary Cohen reported repeat index resulting from the lack of index copy (f393832851). Issue with objective thresholds found upon further inspection.

Reviewed By: mpolson64

Differential Revision: D41888559

fbshipit-source-id: 25f0604f846de295f681980e32a0421321cbb82d
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Dec 19, 2022
1 parent 502f5b0 commit 79221c6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
8 changes: 4 additions & 4 deletions ax/service/tests/test_best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def test_get_trace(self) -> None:

# Multi objective.
exp = get_experiment_with_observations(
observations=[[1, 1], [1, 2], [3, 3], [2, 4], [2, 1]],
observations=[[1, 1], [-1, 100], [1, 2], [3, 3], [2, 4], [2, 1]],
)
self.assertEqual(get_trace(exp), [1, 2, 9, 11, 11])
self.assertEqual(get_trace(exp), [1, 1, 2, 9, 11, 11])

# W/ constraints.
exp = get_experiment_with_observations(
observations=[[1, 1, 1], [1, 2, -1], [3, 3, -1], [2, 4, 1], [2, 1, 1]],
observations=[[-1, 1, 1], [1, 2, 1], [3, 3, -1], [2, 4, 1], [2, 1, 1]],
constrained=True,
)
self.assertEqual(get_trace(exp), [1, 1, 1, 8, 8])
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])

# W/ first objective being minimized.
exp = get_experiment_with_observations(
Expand Down
6 changes: 3 additions & 3 deletions ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_best_from_model_prediction(self) -> None:
)
),
) as mock_model_best_point, self.assertLogs(
logger="ax.service.utils.best_point", level="WARN"
logger=best_point_logger, level="WARN"
) as lg:
# Test bad model fit causes function to resort back to raw data
with patch(
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_best_raw_objective_point(self) -> None:
get_best_raw_objective_point(exp, opt_conf)

# Test constraints work as expected.
observations = [[1.0, 2.0], [3.0, 4.0], [5.0, -6.0]]
observations = [[1.0, 2.0], [3.0, 4.0], [-5.0, -6.0]]
exp = get_experiment_with_observations(
observations=observations,
constrained=True,
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_best_raw_objective_point_unsatisfiable_relative(self) -> None:
trial.mark_completed()
exp.fetch_data()

with self.assertLogs(logger="ax.service.utils.best_point", level="WARN") as lg:
with self.assertLogs(logger=best_point_logger, level="WARN") as lg:
get_best_raw_objective_point(exp, opt_conf)
self.assertTrue(
any("No status quo provided" in warning for warning in lg.output),
Expand Down
8 changes: 4 additions & 4 deletions ax/service/utils/best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ def _is_row_feasible(
falls outside of any outcome constraint's bounds (i.e. we are 95% sure the
bound is not satisfied), else True.
"""
if len(optimization_config.outcome_constraints) < 1:
return pd.Series([True] * len(df))
if len(optimization_config.all_constraints) < 1:
return pd.Series([True] * len(df), index=df.index)

name = df["metric_name"]

Expand All @@ -652,7 +652,7 @@ def _is_row_feasible(
rel_lower_bound = None
rel_upper_bound = None
if status_quo is not None and any(
oc.relative for oc in optimization_config.outcome_constraints
oc.relative for oc in optimization_config.all_constraints
):
# relativize_data expects all arms to come from the same trial, we need to
# format the data as if it was.
Expand Down Expand Up @@ -707,7 +707,7 @@ def oc_mask(oc: OutcomeConstraint) -> pd.Series:

mask = reduce(
lambda left, right: left & right,
map(oc_mask, optimization_config.outcome_constraints),
map(oc_mask, optimization_config.all_constraints),
)
bad_arm_names = (
df[~mask]["arm_name"].tolist()
Expand Down
9 changes: 8 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,14 @@ def get_experiment_with_observations(
optimization_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(metrics=metrics),
objective_thresholds=[
ObjectiveThreshold(metric=metrics[i], bound=0.0, relative=False)
ObjectiveThreshold(
metric=metrics[i],
bound=0.0,
relative=False,
op=ComparisonOp.LEQ
if metrics[i].lower_is_better
else ComparisonOp.GEQ,
)
for i in [0, 1]
],
outcome_constraints=[
Expand Down

0 comments on commit 79221c6

Please sign in to comment.