From 3471fb5aef2be27b94dd646ba69eafed9c1b21c2 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 7 Apr 2021 08:31:28 +0200 Subject: [PATCH 1/9] Changed simple_image_classifier.py to use LightningCLI --- .../basic_examples/simple_image_classifier.py | 66 +++++++------------ 1 file changed, 22 insertions(+), 44 deletions(-) diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 3f7079d665ea8..08d549527a917 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -11,8 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +MNIST simple image classifier example. + +To run: +python simple_image_classifier.py --trainer.max_epochs=50 +""" -from argparse import ArgumentParser from pprint import pprint import torch @@ -21,6 +26,7 @@ import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule +from pytorch_lightning.utilities.cli import LightningCLI class LitClassifier(pl.LightningModule): @@ -32,7 +38,11 @@ class LitClassifier(pl.LightningModule): ) """ - def __init__(self, hidden_dim=128, learning_rate=1e-3): + def __init__( + self, + hidden_dim: int = 128, + learning_rate: float = 0.0001, + ): super().__init__() self.save_hyperparameters() @@ -66,49 +76,17 @@ def test_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - @staticmethod - def add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("LitClassifier") - parser.add_argument('--hidden_dim', type=int, default=128) - parser.add_argument('--learning_rate', type=float, default=0.0001) - return parent_parser - - -def cli_main(): - pl.seed_everything(1234) - - # ------------ - # args - # ------------ - parser = ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) - parser = LitClassifier.add_model_specific_args(parser) - parser = MNISTDataModule.add_argparse_args(parser) - args = parser.parse_args() - - # ------------ - # data - # ------------ - dm = MNISTDataModule.from_argparse_args(args) - - # ------------ - # model - # ------------ - model = LitClassifier(args.hidden_dim, args.learning_rate) - - # ------------ - # training - # ------------ - trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model, datamodule=dm) - - # ------------ - # testing - # ------------ - result = trainer.test(model, datamodule=dm) - pprint(result) + +class MyLightningCLI(LightningCLI): + + def before_instantiate_classes(self): + pl.seed_everything(1234) + + def after_fit(self): + result = self.trainer.test(self.model, datamodule=self.datamodule) + pprint(result) if __name__ == '__main__': cli_lightning_logo() - cli_main() + MyLightningCLI(LitClassifier, MNISTDataModule) From 4462f6bb4b30aefd0a5c3cf0836611ce162c524b Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 9 Apr 2021 15:39:59 +0200 Subject: [PATCH 2/9] - Added fail_untyped, seed_everything, subclass required needed for the examples. - Increased minimum required jsonargparse version to 3.9.0. - Improvements to simple_image_classifier.py example. - Changed autoencoder.py and backbone_image_classifier.py examples to use LightningCLI. - Updated pl_examples/test_examples.py so that tests succeed. --- pl_examples/basic_examples/autoencoder.py | 78 +++++++-------- .../backbone_image_classifier.py | 96 ++++++++++--------- .../basic_examples/simple_image_classifier.py | 10 +- pl_examples/test_examples.py | 18 ++-- pytorch_lightning/utilities/cli.py | 19 ++-- requirements/extra.txt | 2 +- tests/utilities/test_cli.py | 2 + 7 files changed, 119 insertions(+), 106 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 6841b8555ef1f..58ec3cdfe2bb8 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +MNIST autoencoder example. -from argparse import ArgumentParser +To run: +python autoencoder.py --trainer.max_epochs=50 +""" import torch import torch.nn.functional as F @@ -21,6 +25,7 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pytorch_lightning.utilities.cli import LightningCLI if _TORCHVISION_AVAILABLE: from torchvision import transforms @@ -86,45 +91,40 @@ def configure_optimizers(self): return optimizer +class MyDataModule(pl.LightningDataModule): + + def __init__( + self, + batch_size: int = 32, + ): + super().__init__() + dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) + self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) + self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=self.batch_size) + + +class MyLightningCLI(LightningCLI): + + def before_parse_arguments(self, parser): + parser.set_defaults(seed_everything=1234) + + def after_fit(self): + result = self.trainer.test(test_dataloaders=self.datamodule.test_dataloader()) + print(result) + + def cli_main(): - pl.seed_everything(1234) - - # ------------ - # args - # ------------ - parser = ArgumentParser() - parser.add_argument('--batch_size', default=32, type=int) - parser.add_argument('--hidden_dim', type=int, default=64) - parser = pl.Trainer.add_argparse_args(parser) - args = parser.parse_args() - - # ------------ - # data - # ------------ - dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) - mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) - mnist_train, mnist_val = random_split(dataset, [55000, 5000]) - - train_loader = DataLoader(mnist_train, batch_size=args.batch_size) - val_loader = DataLoader(mnist_val, batch_size=args.batch_size) - test_loader = DataLoader(mnist_test, batch_size=args.batch_size) - - # ------------ - # model - # ------------ - model = LitAutoEncoder(args.hidden_dim) - - # ------------ - # training - # ------------ - trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model, train_loader, val_loader) - - # ------------ - # testing - # ------------ - result = trainer.test(test_dataloaders=test_loader) - print(result) + MyLightningCLI(LitAutoEncoder, MyDataModule) if __name__ == '__main__': diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 1c78d264a8681..969ee6ab6e5ad 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +MNIST backbone image classifier example. -from argparse import ArgumentParser +To run: +python backbone_image_classifier.py --trainer.max_epochs=50 +""" import torch from torch.nn import functional as F @@ -20,6 +24,7 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pytorch_lightning.utilities.cli import LightningCLI if _TORCHVISION_AVAILABLE: from torchvision import transforms @@ -58,7 +63,11 @@ class LitClassifier(pl.LightningModule): ) """ - def __init__(self, backbone, learning_rate=1e-3): + def __init__( + self, + backbone, + learning_rate: float = 0.0001, + ): super().__init__() self.save_hyperparameters() self.backbone = backbone @@ -91,53 +100,48 @@ def configure_optimizers(self): # self.hparams available because we called self.save_hyperparameters() return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - @staticmethod - def add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("LitClassifier") - parser.add_argument('--learning_rate', type=float, default=0.0001) - return parent_parser + +class MyDataModule(pl.LightningDataModule): + + def __init__( + self, + batch_size: int = 32, + ): + super().__init__() + dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) + self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) + self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=self.batch_size) + + +class MyLightningCLI(LightningCLI): + + def add_arguments_to_parser(self, parser): + parser.add_class_arguments(Backbone, 'model.backbone') + + def before_parse_arguments(self, parser): + parser.set_defaults(seed_everything=1234) + + def instantiate_model(self): + self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone']) + super().instantiate_model() + + def after_fit(self): + result = self.trainer.test(test_dataloaders=self.datamodule.test_dataloader()) + print(result) def cli_main(): - pl.seed_everything(1234) - - # ------------ - # args - # ------------ - parser = ArgumentParser() - parser.add_argument('--batch_size', default=32, type=int) - parser.add_argument('--hidden_dim', type=int, default=128) - parser = pl.Trainer.add_argparse_args(parser) - parser = LitClassifier.add_model_specific_args(parser) - args = parser.parse_args() - - # ------------ - # data - # ------------ - dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) - mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) - mnist_train, mnist_val = random_split(dataset, [55000, 5000]) - - train_loader = DataLoader(mnist_train, batch_size=args.batch_size) - val_loader = DataLoader(mnist_val, batch_size=args.batch_size) - test_loader = DataLoader(mnist_test, batch_size=args.batch_size) - - # ------------ - # model - # ------------ - model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate) - - # ------------ - # training - # ------------ - trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model, train_loader, val_loader) - - # ------------ - # testing - # ------------ - result = trainer.test(test_dataloaders=test_loader) - print(result) + MyLightningCLI(LitClassifier, MyDataModule) if __name__ == '__main__': diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 08d549527a917..ffb0b83d66f6d 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -79,14 +79,18 @@ def configure_optimizers(self): class MyLightningCLI(LightningCLI): - def before_instantiate_classes(self): - pl.seed_everything(1234) + def before_parse_arguments(self, parser): + parser.set_defaults(seed_everything=1234) def after_fit(self): result = self.trainer.test(self.model, datamodule=self.datamodule) pprint(result) +def cli_main(): + MyLightningCLI(LitClassifier, MNISTDataModule) + + if __name__ == '__main__': cli_lightning_logo() - MyLightningCLI(LitClassifier, MNISTDataModule) + cli_main() diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py index b930957a26346..1d9cd5d9f6d8d 100644 --- a/pl_examples/test_examples.py +++ b/pl_examples/test_examples.py @@ -22,24 +22,24 @@ from pl_examples import _DALI_AVAILABLE ARGS_DEFAULT = """ ---default_root_dir %(tmpdir)s \ ---max_epochs 1 \ ---batch_size 32 \ ---limit_train_batches 2 \ ---limit_val_batches 2 \ +--trainer.default_root_dir %(tmpdir)s \ +--trainer.max_epochs 1 \ +--trainer.limit_train_batches 2 \ +--trainer.limit_val_batches 2 \ +--data.batch_size 32 \ """ ARGS_GPU = ARGS_DEFAULT + """ ---gpus 1 \ +--trainer.gpus 1 \ """ ARGS_DP = ARGS_DEFAULT + """ ---gpus 2 \ ---accelerator dp \ +--trainer.gpus 2 \ +--trainer.accelerator dp \ """ ARGS_AMP = """ ---precision 16 \ +--trainer.precision 16 \ """ diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 33c424829606a..97fd75ccbed08 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,13 +13,14 @@ # limitations under the License. import os from argparse import Namespace -from typing import Any, Dict, Type, Union +from typing import Any, Dict, Optional, Type, Union from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.seed import seed_everything _JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") if _JSONARGPARSE_AVAILABLE: @@ -63,8 +64,8 @@ def add_lightning_class_args( """ assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule)) if subclass_mode: - return self.add_subclass_arguments(lightning_class, nested_key) - return self.add_class_arguments(lightning_class, nested_key) + return self.add_subclass_arguments(lightning_class, nested_key, required=True) + return self.add_class_arguments(lightning_class, nested_key, fail_untyped=False) class SaveConfigCallback(Callback): @@ -161,6 +162,8 @@ def __init__( self.add_core_arguments_to_parser() self.before_parse_arguments(self.parser) self.parse_arguments() + if self.config['seed_everything'] is not None: + seed_everything(self.config['seed_everything']) self.before_instantiate_classes() self.instantiate_classes() self.prepare_fit_kwargs() @@ -178,10 +181,14 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: Args: parser: The argument parser object to which arguments should be added """ - pass def add_core_arguments_to_parser(self) -> None: """Adds arguments from the core classes to the parser""" + self.parser.add_argument( + '--seed_everything', + type=Optional[int], + help='Set to an int to run seed_everything with this value before classes instantiation', + ) self.parser.add_lightning_class_args(self.trainer_class, 'trainer') trainer_defaults = {'trainer.' + k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} self.parser.set_defaults(trainer_defaults) @@ -195,7 +202,6 @@ def before_parse_arguments(self, parser: LightningArgumentParser) -> None: Args: parser: The argument parser object that will be used to parse """ - pass def parse_arguments(self) -> None: """Parses command line arguments and stores it in self.config""" @@ -203,7 +209,6 @@ def parse_arguments(self) -> None: def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes""" - pass def instantiate_classes(self) -> None: """Instantiates the classes using settings from self.config""" @@ -249,7 +254,6 @@ def prepare_fit_kwargs(self) -> None: def before_fit(self) -> None: """Implement to run some code before fit is started""" - pass def fit(self) -> None: """Runs fit of the instantiated trainer class and prepared fit keyword arguments""" @@ -257,4 +261,3 @@ def fit(self) -> None: def after_fit(self) -> None: """Implement to run some code after fit has finished""" - pass diff --git a/requirements/extra.txt b/requirements/extra.txt index 46a726fe05c43..c1dcc1cd4032a 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -9,5 +9,5 @@ onnxruntime>=1.3.0 hydra-core>=1.0 # todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -jsonargparse[signatures]>=3.3.1 +jsonargparse[signatures]>=3.9.0 deepspeed>=0.3.13 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 9b2c825c7aeb8..e6c067266cc8b 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -284,12 +284,14 @@ def test_lightning_cli_args(tmpdir): f'--trainer.default_root_dir={tmpdir}', '--trainer.max_epochs=1', '--trainer.weights_summary=null', + '--seed_everything=1234', ] with mock.patch('sys.argv', ['any.py'] + cli_args): cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]}) assert cli.fit_result == 1 + assert cli.config['seed_everything'] == 1234 config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' assert os.path.isfile(config_path) with open(config_path) as f: From 51483fc1f0deb2ea87b901cb95c6d744f76b6747 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 9 Apr 2021 16:19:13 +0200 Subject: [PATCH 3/9] Disabled testcode for LightningCLI subclasses since this requires a config --- docs/source/common/lightning_cli.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 6d8ae6701fd40..a5cdf9f93e14f 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -224,7 +224,7 @@ datamodule class. However, there are many cases in which the objective is to eas multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is specified by an import path and init arguments. For example, with a tool implemented as: -.. testcode:: +.. code-block:: python from pytorch_lightning.utilities.cli import LightningCLI From 5873a2fa32a3f17aae9b65e4fe3ab50844e644b4 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 13 Apr 2021 08:17:35 +0200 Subject: [PATCH 4/9] - Added seed_everything_default to LightningCLI. - Changed dali_image_classifier.py example to use LightningCLI. - Simplified use of LightningCLI in autoencoder.py, backbone_image_classifier.py and simple_image_classifier.py. - Adapted other files related to the changed examples. --- pl_examples/basic_examples/README.md | 12 +-- pl_examples/basic_examples/autoencoder.py | 14 +-- .../backbone_image_classifier.py | 11 +-- .../basic_examples/dali_image_classifier.py | 91 ++++++++----------- .../basic_examples/simple_image_classifier.py | 16 +--- pl_examples/basic_examples/submit_ddp2_job.sh | 6 +- pl_examples/basic_examples/submit_ddp_job.sh | 6 +- pl_examples/run_ddp-examples.sh | 4 +- pytorch_lightning/utilities/cli.py | 4 + 9 files changed, 68 insertions(+), 96 deletions(-) diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md index b02ea21c7940d..9c37c98a315c7 100644 --- a/pl_examples/basic_examples/README.md +++ b/pl_examples/basic_examples/README.md @@ -8,10 +8,10 @@ Trains MNIST where the model is defined inside the `LightningModule`. python simple_image_classifier.py # gpus (any number) -python simple_image_classifier.py --gpus 2 +python simple_image_classifier.py --trainer.gpus 2 # dataparallel -python simple_image_classifier.py --gpus 2 --distributed_backend 'dp' +python simple_image_classifier.py --trainer.gpus 2 --trainer.distributed_backend 'dp' ``` --- @@ -30,10 +30,10 @@ Generic image classifier with an arbitrary backbone (ie: a simple system) python backbone_image_classifier.py # gpus (any number) -python backbone_image_classifier.py --gpus 2 +python backbone_image_classifier.py --trainer.gpus 2 # dataparallel -python backbone_image_classifier.py --gpus 2 --distributed_backend 'dp' +python backbone_image_classifier.py --trainer.gpus 2 --trainer.distributed_backend 'dp' ``` --- @@ -44,10 +44,10 @@ Showing the power of a system... arbitrarily complex training loops python autoencoder.py # gpus (any number) -python autoencoder.py --gpus 2 +python autoencoder.py --trainer.gpus 2 # dataparallel -python autoencoder.py --gpus 2 --distributed_backend 'dp' +python autoencoder.py --trainer.gpus 2 --trainer.distributed_backend 'dp' ``` --- # Multi-node example diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 58ec3cdfe2bb8..3a016eb93269c 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -113,18 +113,10 @@ def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) -class MyLightningCLI(LightningCLI): - - def before_parse_arguments(self, parser): - parser.set_defaults(seed_everything=1234) - - def after_fit(self): - result = self.trainer.test(test_dataloaders=self.datamodule.test_dataloader()) - print(result) - - def cli_main(): - MyLightningCLI(LitAutoEncoder, MyDataModule) + cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234) + result = cli.trainer.test(cli.model, datamodule=cli.datamodule) + print(result) if __name__ == '__main__': diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 969ee6ab6e5ad..68b3145e32fb4 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -128,20 +128,15 @@ class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_class_arguments(Backbone, 'model.backbone') - def before_parse_arguments(self, parser): - parser.set_defaults(seed_everything=1234) - def instantiate_model(self): self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone']) super().instantiate_model() - def after_fit(self): - result = self.trainer.test(test_dataloaders=self.datamodule.test_dataloader()) - print(result) - def cli_main(): - MyLightningCLI(LitClassifier, MyDataModule) + cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234) + result = cli.trainer.test(cli.model, datamodule=cli.datamodule) + print(result) if __name__ == '__main__': diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 08bf64da252bf..e3eb3026db97a 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from argparse import ArgumentParser from distutils.version import LooseVersion from random import shuffle from warnings import warn @@ -30,6 +29,7 @@ _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo, ) +from pytorch_lightning.utilities.cli import LightningCLI if _TORCHVISION_AVAILABLE: from torchvision import transforms @@ -136,7 +136,11 @@ def __len__(self): class LitClassifier(pl.LightningModule): - def __init__(self, hidden_dim=128, learning_rate=1e-3): + def __init__( + self, + hidden_dim: int = 128, + learning_rate: float = 0.0001, + ): super().__init__() self.save_hyperparameters() @@ -173,64 +177,43 @@ def test_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - @staticmethod - def add_model_specific_args(parent_parser): - parser = parent_parser.add_argument_group("LitClassifier") - parser.add_argument('--hidden_dim', type=int, default=128) - parser.add_argument('--learning_rate', type=float, default=0.0001) - return parent_parser + +class MyDataModule(pl.LightningDataModule): + + def __init__( + self, + batch_size: int = 32, + ): + super().__init__() + dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) + self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) + self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000]) + + eii_train = ExternalMNISTInputIterator(self.mnist_train, batch_size) + eii_val = ExternalMNISTInputIterator(self.mnist_val, batch_size) + eii_test = ExternalMNISTInputIterator(self.mnist_test, batch_size) + + self.pipe_train = ExternalSourcePipeline(batch_size=batch_size, eii=eii_train, num_threads=2, device_id=0) + self.pipe_val = ExternalSourcePipeline(batch_size=batch_size, eii=eii_val, num_threads=2, device_id=0) + self.pipe_test = ExternalSourcePipeline(batch_size=batch_size, eii=eii_test, num_threads=2, device_id=0) + + def train_dataloader(self): + return DALIClassificationLoader(self.pipe_train, size=len(self.mnist_train), auto_reset=True, fill_last_batch=True) + + def val_dataloader(self): + return DALIClassificationLoader(self.pipe_val, size=len(self.mnist_val), auto_reset=True, fill_last_batch=False) + + def test_dataloader(self): + return DALIClassificationLoader(self.pipe_test, size=len(self.mnist_test), auto_reset=True, fill_last_batch=False) def cli_main(): if not _DALI_AVAILABLE: return - pl.seed_everything(1234) - - # ------------ - # args - # ------------ - parser = ArgumentParser() - parser.add_argument('--batch_size', default=32, type=int) - parser = pl.Trainer.add_argparse_args(parser) - parser = LitClassifier.add_model_specific_args(parser) - args = parser.parse_args() - - # ------------ - # data - # ------------ - dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) - mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) - mnist_train, mnist_val = random_split(dataset, [55000, 5000]) - - eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size) - eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size) - eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size) - - pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0) - train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True) - - pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0) - val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False) - - pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0) - test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False) - - # ------------ - # model - # ------------ - model = LitClassifier(args.hidden_dim, args.learning_rate) - - # ------------ - # training - # ------------ - trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model, train_loader, val_loader) - - # ------------ - # testing - # ------------ - trainer.test(test_dataloaders=test_loader) + cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234) + result = cli.trainer.test(cli.model, datamodule=cli.datamodule) + print(result) if __name__ == "__main__": diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index ffb0b83d66f6d..d401e884a2f18 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -18,8 +18,6 @@ python simple_image_classifier.py --trainer.max_epochs=50 """ -from pprint import pprint - import torch from torch.nn import functional as F @@ -77,18 +75,10 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) -class MyLightningCLI(LightningCLI): - - def before_parse_arguments(self, parser): - parser.set_defaults(seed_everything=1234) - - def after_fit(self): - result = self.trainer.test(self.model, datamodule=self.datamodule) - pprint(result) - - def cli_main(): - MyLightningCLI(LitClassifier, MNISTDataModule) + cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234) + result = cli.trainer.test(cli.model, datamodule=cli.datamodule) + print(result) if __name__ == '__main__': diff --git a/pl_examples/basic_examples/submit_ddp2_job.sh b/pl_examples/basic_examples/submit_ddp2_job.sh index 026589a604c36..e31c7a185b73c 100755 --- a/pl_examples/basic_examples/submit_ddp2_job.sh +++ b/pl_examples/basic_examples/submit_ddp2_job.sh @@ -24,4 +24,8 @@ source activate $1 # ------------------------- # run script from above -srun python3 simple_image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 --max_epochs 5 +srun python3 simple_image_classifier.py \ + --trainer.accelerator 'ddp2' \ + --trainer.gpus 2 \ + --trainer.num_nodes 2 \ + --trainer.max_epochs 5 diff --git a/pl_examples/basic_examples/submit_ddp_job.sh b/pl_examples/basic_examples/submit_ddp_job.sh index b4f5ff0a64d92..177c19b3fdd72 100755 --- a/pl_examples/basic_examples/submit_ddp_job.sh +++ b/pl_examples/basic_examples/submit_ddp_job.sh @@ -24,4 +24,8 @@ source activate $1 # ------------------------- # run script from above -srun python3 simple_image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 --max_epochs 5 +srun python3 simple_image_classifier.py \ + --trainer.accelerator 'ddp' \ + --trainer.gpus 2 \ + --trainer.num_nodes 2 \ + --trainer.max_epochs 5 diff --git a/pl_examples/run_ddp-examples.sh b/pl_examples/run_ddp-examples.sh index 6cc36364e397d..6a1a63b5205e7 100644 --- a/pl_examples/run_ddp-examples.sh +++ b/pl_examples/run_ddp-examples.sh @@ -1,7 +1,7 @@ #!/bin/bash -ARGS_EXTRA_DDP=" --gpus 2 --accelerator ddp" -ARGS_EXTRA_AMP=" --precision 16" +ARGS_EXTRA_DDP=" --trainer.gpus 2 --trainer.accelerator ddp" +ARGS_EXTRA_AMP=" --trainer.precision 16" python pl_examples/basic_examples/simple_image_classifier.py $@ ${ARGS_EXTRA_DDP} python pl_examples/basic_examples/simple_image_classifier.py $@ ${ARGS_EXTRA_DDP} ${ARGS_EXTRA_AMP} diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 97fd75ccbed08..93729a53db25d 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -97,6 +97,7 @@ def __init__( save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, trainer_class: Type[Trainer] = Trainer, trainer_defaults: Dict[str, Any] = None, + seed_everything_default: int = None, description: str = 'pytorch-lightning trainer command line tool', env_prefix: str = 'PL', env_parse: bool = False, @@ -132,6 +133,7 @@ def __init__( save_config_callback: A callback class to save the training config. trainer_class: An optional extension of the Trainer class. trainer_defaults: Set to override Trainer defaults or add persistent callbacks. + seed_everything_default: Default value for seed_everything argument. description: Description of the tool shown when running --help. env_prefix: Prefix for environment variables. env_parse: Whether environment variable parsing is enabled. @@ -152,6 +154,7 @@ def __init__( self.save_config_callback = save_config_callback self.trainer_class = trainer_class self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults + self.seed_everything_default = seed_everything_default self.subclass_mode_model = subclass_mode_model self.subclass_mode_data = subclass_mode_data self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs @@ -187,6 +190,7 @@ def add_core_arguments_to_parser(self) -> None: self.parser.add_argument( '--seed_everything', type=Optional[int], + default=self.seed_everything_default, help='Set to an int to run seed_everything with this value before classes instantiation', ) self.parser.add_lightning_class_args(self.trainer_class, 'trainer') From 79e8b27293224499692e748442fd4860f973f39b Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 13 Apr 2021 08:30:45 +0200 Subject: [PATCH 5/9] Fix pep8 issues in dali_image_classifier.py --- .../basic_examples/dali_image_classifier.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index e3eb3026db97a..a4db77115f647 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -198,13 +198,28 @@ def __init__( self.pipe_test = ExternalSourcePipeline(batch_size=batch_size, eii=eii_test, num_threads=2, device_id=0) def train_dataloader(self): - return DALIClassificationLoader(self.pipe_train, size=len(self.mnist_train), auto_reset=True, fill_last_batch=True) + return DALIClassificationLoader( + self.pipe_train, + size=len(self.mnist_train), + auto_reset=True, + fill_last_batch=True + ) def val_dataloader(self): - return DALIClassificationLoader(self.pipe_val, size=len(self.mnist_val), auto_reset=True, fill_last_batch=False) + return DALIClassificationLoader( + self.pipe_val, + size=len(self.mnist_val), + auto_reset=True, + fill_last_batch=False + ) def test_dataloader(self): - return DALIClassificationLoader(self.pipe_test, size=len(self.mnist_test), auto_reset=True, fill_last_batch=False) + return DALIClassificationLoader( + self.pipe_test, + size=len(self.mnist_test), + auto_reset=True, + fill_last_batch=False + ) def cli_main(): From 1ed66b60f45474ad6eb9933cdf72c14cbf9e600b Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 13 Apr 2021 10:03:28 +0200 Subject: [PATCH 6/9] Updated azure-pipelines.yml --- azure-pipelines.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 85664bac74b67..574ef265fbfb4 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -117,8 +117,8 @@ jobs: set -e python -m pytest pl_examples -v --maxfail=2 --durations=0 pip install . --user --quiet - bash pl_examples/run_examples-args.sh --gpus 1 --max_epochs 1 --batch_size 64 --limit_train_batches 5 --limit_val_batches 3 - bash pl_examples/run_ddp-examples.sh --max_epochs 1 --batch_size 32 --limit_train_batches 2 --limit_val_batches 2 + bash pl_examples/run_examples-args.sh --trainer.gpus 1 --trainer.max_epochs 1 --data.batch_size 64 --trainer.limit_train_batches 5 --limit_val_batches 3 + bash pl_examples/run_ddp-examples.sh --trainer.max_epochs 1 --data.batch_size 32 --trainer.limit_train_batches 2 --trainer.limit_val_batches 2 # cd pl_examples/basic_examples # bash submit_ddp_job.sh # bash submit_ddp2_job.sh From 34acebaafebd5a23ae7657aaf19a654c41342725 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 13 Apr 2021 10:53:02 +0200 Subject: [PATCH 7/9] Updated azure-pipelines.yml --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 574ef265fbfb4..bd11b05812a4f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -117,7 +117,7 @@ jobs: set -e python -m pytest pl_examples -v --maxfail=2 --durations=0 pip install . --user --quiet - bash pl_examples/run_examples-args.sh --trainer.gpus 1 --trainer.max_epochs 1 --data.batch_size 64 --trainer.limit_train_batches 5 --limit_val_batches 3 + bash pl_examples/run_examples-args.sh --trainer.gpus 1 --trainer.max_epochs 1 --data.batch_size 64 --trainer.limit_train_batches 5 --trainer.limit_val_batches 3 bash pl_examples/run_ddp-examples.sh --trainer.max_epochs 1 --data.batch_size 32 --trainer.limit_train_batches 2 --trainer.limit_val_batches 2 # cd pl_examples/basic_examples # bash submit_ddp_job.sh From f2fcafbc7a217f5aae6ac023108390bc395aa71e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 13 Apr 2021 13:43:14 +0200 Subject: [PATCH 8/9] distributed_backend -> accelerator --- pl_examples/basic_examples/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md index 9c37c98a315c7..ffacf6895e3c8 100644 --- a/pl_examples/basic_examples/README.md +++ b/pl_examples/basic_examples/README.md @@ -11,7 +11,7 @@ python simple_image_classifier.py python simple_image_classifier.py --trainer.gpus 2 # dataparallel -python simple_image_classifier.py --trainer.gpus 2 --trainer.distributed_backend 'dp' +python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp' ``` --- @@ -33,7 +33,7 @@ python backbone_image_classifier.py python backbone_image_classifier.py --trainer.gpus 2 # dataparallel -python backbone_image_classifier.py --trainer.gpus 2 --trainer.distributed_backend 'dp' +python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp' ``` --- @@ -47,7 +47,7 @@ python autoencoder.py python autoencoder.py --trainer.gpus 2 # dataparallel -python autoencoder.py --trainer.gpus 2 --trainer.distributed_backend 'dp' +python autoencoder.py --trainer.gpus 2 --trainer.accelerator 'dp' ``` --- # Multi-node example From c2bfea33711e343a2889be54e12359b462a24515 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 13 Apr 2021 19:42:56 +0200 Subject: [PATCH 9/9] Changed profiler_example.py to use LightningCLI. --- .../basic_examples/profiler_example.py | 38 ++++++++----------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py index ca640a96f9588..0a6f855b2f109 100644 --- a/pl_examples/basic_examples/profiler_example.py +++ b/pl_examples/basic_examples/profiler_example.py @@ -23,7 +23,6 @@ """ import sys -from argparse import ArgumentParser import torch import torchvision @@ -31,27 +30,26 @@ import torchvision.transforms as T from pl_examples import cli_lightning_logo -from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning.utilities.cli import LightningCLI DEFAULT_CMD_LINE = ( - "--max_epochs", - "1", - "--limit_train_batches", - "15", - "--limit_val_batches", - "15", - "--profiler", - "pytorch", - "--gpus", - f"{int(torch.cuda.is_available())}", + "--trainer.max_epochs=1", + "--trainer.limit_train_batches=15", + "--trainer.limit_val_batches=15", + "--trainer.profiler=pytorch", + f"--trainer.gpus={int(torch.cuda.is_available())}", ) class ModelToProfile(LightningModule): - def __init__(self, model): + def __init__( + self, + name: str = "resnet50" + ): super().__init__() - self.model = model + self.model = getattr(models, name)(pretrained=True) self.criterion = torch.nn.CrossEntropyLoss() def training_step(self, batch, batch_idx): @@ -85,16 +83,10 @@ def val_dataloader(self, *args, **kwargs): def cli_main(): + if len(sys.argv) == 1: + sys.argv += DEFAULT_CMD_LINE - parser = ArgumentParser() - parser = Trainer.add_argparse_args(parser) - cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE - args = parser.parse_args(args=cmd_line) - - model = ModelToProfile(models.resnet50(pretrained=True)) - datamodule = CIFAR10DataModule() - trainer = Trainer(**vars(args)) - trainer.fit(model, datamodule=datamodule) + LightningCLI(ModelToProfile, CIFAR10DataModule) if __name__ == '__main__':