From cf2f19562df92e11924a7f0de69b3f5855d14fe1 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 27 Jun 2024 22:07:57 -0700 Subject: [PATCH] Downgrade error to beta-warning for gs selection in AxClient Summary: Previously we added this error in D56048665 in response to a github issue, however now that we are adding support for GenerationStrategy selection for BatchTrials this can be a warning that this method is in development. In the future we should be able to fully remove the method Differential Revision: D59143308 --- ax/service/ax_client.py | 6 +++--- ax/telemetry/tests/test_ax_client.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 4370e3ab86a..ff5b988675c 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -1732,9 +1732,9 @@ def _set_generation_strategy( "use_batch_trials" in choose_generation_strategy_kwargs and type(self) is AxClient ): - raise UnsupportedError( - "AxClient API does not support batch trials yet." - " We plan to add this support in coming versions." + logger.warning( + "Selecting a GenerationStrategy when using BatchTrials is in beta. " + "Double check the recommended strategy matches your expectations." ) random_seed = choose_generation_strategy_kwargs.pop( "random_seed", self._random_seed diff --git a/ax/telemetry/tests/test_ax_client.py b/ax/telemetry/tests/test_ax_client.py index 5a7ee677a6c..56ba5887f5b 100644 --- a/ax/telemetry/tests/test_ax_client.py +++ b/ax/telemetry/tests/test_ax_client.py @@ -6,6 +6,8 @@ # pyre-strict +import logging +from logging import Logger from typing import Dict, List, Sequence, Union import numpy as np @@ -16,8 +18,11 @@ from ax.telemetry.ax_client import AxClientCompletedRecord, AxClientCreatedRecord from ax.telemetry.experiment import ExperimentCompletedRecord, ExperimentCreatedRecord from ax.telemetry.generation_strategy import GenerationStrategyCreatedRecord +from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase +logger: Logger = get_logger(__name__) + class TestAxClient(TestCase): def test_ax_client_created_record_from_ax_client(self) -> None: @@ -133,11 +138,8 @@ def test_ax_client_completed_record_from_ax_client(self) -> None: def test_batch_trial_warning(self) -> None: ax_client = AxClient() - error_msg = ( - "AxClient API does not support batch trials yet." - " We plan to add this support in coming versions." - ) - with self.assertRaisesRegex(UnsupportedError, error_msg): + warning_msg = "GenerationStrategy when using BatchTrials is in beta." + with self.assertLogs(AxClient.__module__, logging.WARNING) as logger: ax_client.create_experiment( name="test_experiment", parameters=[ @@ -149,6 +151,10 @@ def test_batch_trial_warning(self) -> None: "use_batch_trials": True, }, ) + self.assertTrue( + any(warning_msg in output for output in logger.output), + logger.output, + ) def _compare_axclient_completed_records( self, record: AxClientCompletedRecord, expected: AxClientCompletedRecord