diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index d97ef9ccddebb..7b064aa4fbbca 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -662,12 +662,6 @@ def get_automatic(
                 "#optimizers-and-learning-rate-schedulers"
             )
 
-        if is_overridden("configure_optimizers", self.model):
-            warnings._warn(
-                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
-                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
-            )
-
         optimizer_class = parser._optimizers[optimizers[0]][0]
         optimizer_init = self._get(self.config_init, optimizers[0])
         if not isinstance(optimizer_class, tuple):
@@ -675,6 +669,7 @@ def get_automatic(
         if not optimizer_init:
            # optimizers were registered automatically but not passed by the user
             return
+
         lr_scheduler_init = None
         if lr_schedulers:
             lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
@@ -691,6 +686,11 @@ def configure_optimizers(
                 lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
                 return [optimizer], [lr_scheduler]
 
+        if is_overridden("configure_optimizers", self.model):
+            warnings._warn(
+                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
+                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
+            )
         self.model.configure_optimizers = MethodType(configure_optimizers, self.model)
 
     def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index bff5d7e9111e4..8c5d441c0925c 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -46,6 +46,7 @@
 from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
+from tests.helpers.utils import no_warning_call
 
 torchvision_version = version.parse("0")
 if _TORCHVISION_AVAILABLE:
@@ -1265,3 +1266,11 @@ class TestCallback(Callback):
     )
     # the existing config is not updated
     assert cli.config_init["trainer"]["max_epochs"] is None
+
+
+def test_cli_configure_optimizers_warning(tmpdir):
+    match = "configure_optimizers` will be overridden by `LightningCLI"
+    with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match):
+        LightningCLI(BoringModel, run=False)
+    with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match):
+        LightningCLI(BoringModel, run=False)