Skip to content

Commit

Permalink
Avoid CLI warning when configure_optimizers will not be overridden
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Sep 17, 2021
1 parent 856ed10 commit 7916e24
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
12 changes: 6 additions & 6 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,19 +662,14 @@ def get_automatic(
"#optimizers-and-learning-rate-schedulers"
)

if is_overridden("configure_optimizers", self.model):
warnings._warn(
f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
)

optimizer_class = parser._optimizers[optimizers[0]][0]
optimizer_init = self._get(self.config_init, optimizers[0])
if not isinstance(optimizer_class, tuple):
optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
if not optimizer_init:
# optimizers were registered automatically but not passed by the user
return

lr_scheduler_init = None
if lr_schedulers:
lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
Expand All @@ -691,6 +686,11 @@ def configure_optimizers(
lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
return [optimizer], [lr_scheduler]

if is_overridden("configure_optimizers", self.model):
warnings._warn(
f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
)
self.model.configure_optimizers = MethodType(configure_optimizers, self.model)

def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
Expand Down
9 changes: 9 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.utils import no_warning_call

torchvision_version = version.parse("0")
if _TORCHVISION_AVAILABLE:
Expand Down Expand Up @@ -1265,3 +1266,11 @@ class TestCallback(Callback):
)
# the existing config is not updated
assert cli.config_init["trainer"]["max_epochs"] is None


def test_cli_configure_optimizers_warning(tmpdir):
match = "configure_optimizers` will be overridden by `LightningCLI"
with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=match):
LightningCLI(BoringModel, run=False)
with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=match):
LightningCLI(BoringModel, run=False)

0 comments on commit 7916e24

Please sign in to comment.