From 76e7ba8cf9f6e769a6468e72af8053263ea96f8c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 08:43:46 -0700 Subject: [PATCH 1/4] Stricter argument hygiene for passing datamodules to trainer.fit or trainer.tune --- pytorch_lightning/trainer/trainer.py | 14 +++++++++++++- tests/deprecated_api/test_remove_1-5.py | 15 +++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 12fc1c3288325..f4dd0a8e78b74 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -58,8 +58,9 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, parsing from pytorch_lightning.utilities.debugging import InternalDebugger +from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden @@ -438,7 +439,12 @@ def fit( self.training = True # if a datamodule comes in as the second arg, then fix it for the user + # deprecated in v1.3 and will be removed in v1.5 if isinstance(train_dataloader, LightningDataModule): + rank_zero_deprecation( + "Passing the datamodule without using named arguments is deprecated in v1.3 and will be removed in v1.5." + " Pass the datamodule explicitly to trainer.fit(..., datamodule=dm)" + ) datamodule = train_dataloader train_dataloader = None # If you supply a datamodule you can't supply train_dataloader or val_dataloaders @@ -662,9 +668,15 @@ def tune( self.tuning = True # if a datamodule comes in as the second arg, then fix it for the user + # deprecated in v1.3 and will be removed in v1.5 if isinstance(train_dataloader, LightningDataModule): + rank_zero_deprecation( + "Passing the datamodule without using named arguments is deprecated in v1.3 and will be removed in v1.5." + " Pass the datamodule explicitly to trainer.fit(..., datamodule=dm)" + ) datamodule = train_dataloader train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: raise MisconfigurationException( diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 7285c5d176444..bf330ecbe61f9 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -380,3 +380,18 @@ def test_v1_5_0_lighting_module_grad_norm(tmpdir): model = BoringModel() with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): model.grad_norm(2) + + +def test_v1_5_0_datamodule_named_argument(tmpdir): + model = BoringModel() + dm = BoringDataModule() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + checkpoint_callback=False, + logger=False, + ) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + trainer.fit(model, dm) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + trainer.tune(model, dm) From bcc8ff25b2c10d105226d49b8e6cb2491faf3bec Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 08:47:55 -0700 Subject: [PATCH 2/4] chlog --- CHANGELOG.md | 3 +++ tests/deprecated_api/test_remove_1-5.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b629e9e72a9a5..19f4f2e9e7cd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated passing datamodule as an unnamed argument to `Trainer.fit` and `Trainer.tune` in favor of explicitly specifying `Trainer.fit/tune(model, datamodule=dm)` ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7329/)) + + - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292)) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index bf330ecbe61f9..0054a683767f3 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -27,7 +27,7 @@ from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache from tests.deprecated_api import no_deprecated_call -from tests.helpers import BoringModel +from tests.helpers import BoringDataModule, BoringModel from tests.helpers.utils import no_warning_call From 99ff5d0f9571093548cb056e1d45c966ce5e55d2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 08:51:56 -0700 Subject: [PATCH 3/4] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f4dd0a8e78b74..ecf44586deb23 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -442,7 +442,8 @@ def fit( # deprecated in v1.3 and will be removed in v1.5 if isinstance(train_dataloader, LightningDataModule): rank_zero_deprecation( - "Passing the datamodule without using named arguments is deprecated in v1.3 and will be removed in v1.5." + "Passing the datamodule without using named arguments is deprecated in v1.3 " + " and will be removed in v1.5." " Pass the datamodule explicitly to trainer.fit(..., datamodule=dm)" ) datamodule = train_dataloader @@ -671,7 +672,8 @@ def tune( # deprecated in v1.3 and will be removed in v1.5 if isinstance(train_dataloader, LightningDataModule): rank_zero_deprecation( - "Passing the datamodule without using named arguments is deprecated in v1.3 and will be removed in v1.5." + "Passing the datamodule without using named arguments is deprecated in v1.3 " + " and will be removed in v1.5." " Pass the datamodule explicitly to trainer.fit(..., datamodule=dm)" ) datamodule = train_dataloader From 33ba6bbe022ccd0992585844f0ce4443752b93b6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 09:07:24 -0700 Subject: [PATCH 4/4] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ecf44586deb23..95703f5723d27 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -442,7 +442,7 @@ def fit( # deprecated in v1.3 and will be removed in v1.5 if isinstance(train_dataloader, LightningDataModule): rank_zero_deprecation( - "Passing the datamodule without using named arguments is deprecated in v1.3 " + "Passing the datamodule without using named arguments is deprecated in v1.3" " and will be removed in v1.5." " Pass the datamodule explicitly to trainer.fit(..., datamodule=dm)" ) @@ -672,7 +672,7 @@ def tune( # deprecated in v1.3 and will be removed in v1.5 if isinstance(train_dataloader, LightningDataModule): rank_zero_deprecation( - "Passing the datamodule without using named arguments is deprecated in v1.3 " + "Passing the datamodule without using named arguments is deprecated in v1.3" " and will be removed in v1.5." " Pass the datamodule explicitly to trainer.fit(..., datamodule=dm)" )