Skip to content

Commit

Permalink
In 'best_point' don't require GeneratorRun to have best_arm_predictio…
Browse files Browse the repository at this point in the history
…ns to predict from model (#2767)

Summary:
Pull Request resolved: #2767

Context:

`get_best_parameters_from_model_predictions_with_trial_index` will only predict from a model if there are `best_arm_predictions` on the `GeneratorRun`. This doesn't make sense, since it's about to construct and fit a new model and use it to generate predicts. Any existing `best_arm_predictions` are not used.

This PR:
* Removes the `gr.best_arm_predictions is not None` check
* Changes how some imported functions are referenced in `best_point_mixin.py` (doesn't change functionality)

Reviewed By: mpolson64

Differential Revision: D62594017

fbshipit-source-id: 25b07ec5668dc24572d5121a2db37b181d7bb36c
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 13, 2024
1 parent 81c945f commit 2fc80f1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
12 changes: 4 additions & 8 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from ax.utils.testing.mock import fast_botorch_optimize
from ax.utils.testing.modeling_stubs import get_observation1, get_observation1trans
from botorch.test_functions.multi_objective import BraninCurrin
from pyre_extensions import none_throws

if TYPE_CHECKING:
from ax.core.types import TTrialEvaluation
Expand Down Expand Up @@ -1477,20 +1478,15 @@ def test_update_running_trial_with_intermediate_data(self) -> None:
)
def test_get_best_point_no_model_predictions(
self,
# pyre-fixme[2]: Parameter must be annotated.
mock_get_best_parameters_from_model_predictions_with_trial_index,
mock_get_best_parameters_from_model_predictions_with_trial_index: Mock,
) -> None:
ax_client = get_branin_optimization()
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(trial_index=idx, raw_data={"branin": (0, 0.0)})
# pyre-fixme[23]: Unable to unpack `Optional[Tuple[int, Dict[str,
# typing.Union[None, bool, float, int, str]], Optional[Tuple[Dict[str, float],
# Optional[Dict[str, typing.Dict[str, float]]]]]]]` into 3 values.
best_idx, best_params, _ = ax_client.get_best_trial()
best_idx, best_params, _ = none_throws(ax_client.get_best_trial())
self.assertEqual(best_idx, idx)
self.assertEqual(best_params, params)
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
self.assertEqual(ax_client.get_best_parameters()[0], params)
self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params)
mock_get_best_parameters_from_model_predictions_with_trial_index.assert_called()
mock_get_best_parameters_from_model_predictions_with_trial_index.reset_mock()
ax_client.get_best_parameters(use_model_predictions=False)
Expand Down
9 changes: 8 additions & 1 deletion ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,14 @@ def test_best_from_model_prediction(self) -> None:
mock_model_best_point.assert_called()

# Assert the non-mocked method works correctly as well
self.assertIsNotNone(get_best_parameters(exp, Models))
best_params = get_best_parameters(exp, Models)
self.assertIsNotNone(best_params)
# It works even when there are no predictions already stored on the
# GeneratorRun
for trial in exp.trials.values():
trial.generator_run._best_arm_predictions = None
best_params_no_gr = get_best_parameters(exp, Models)
self.assertEqual(best_params, best_params_no_gr)

def test_best_raw_objective_point(self) -> None:
exp = get_branin_experiment()
Expand Down
2 changes: 1 addition & 1 deletion ax/service/utils/best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def get_best_parameters_from_model_predictions_with_trial_index(
# In theory batch_trial can have >1 gr, grab the first
gr = trial.generator_run_structs[0].generator_run

if gr is not None and gr.best_arm_predictions is not None:
if gr is not None:
data = experiment.lookup_data(trial_indices=trial_indices)

try:
Expand Down

0 comments on commit 2fc80f1

Please sign in to comment.