
Changed basic_examples to use LightningCLI (#6862)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
mauvilsa and carmocca authored Apr 15, 2021
1 parent f645df5 commit f852a4f
Showing 15 changed files with 203 additions and 233 deletions.
azure-pipelines.yml (4 changes: 2 additions & 2 deletions)
@@ -116,8 +116,8 @@ jobs:
set -e
python -m pytest pl_examples -v --maxfail=2 --durations=0
pip install . --user --quiet
-bash pl_examples/run_examples-args.sh --gpus 1 --max_epochs 1 --batch_size 64 --limit_train_batches 5 --limit_val_batches 3
-bash pl_examples/run_ddp-examples.sh --max_epochs 1 --batch_size 32 --limit_train_batches 2 --limit_val_batches 2
+bash pl_examples/run_examples-args.sh --trainer.gpus 1 --trainer.max_epochs 1 --data.batch_size 64 --trainer.limit_train_batches 5 --trainer.limit_val_batches 3
+bash pl_examples/run_ddp-examples.sh --trainer.max_epochs 1 --data.batch_size 32 --trainer.limit_train_batches 2 --trainer.limit_val_batches 2
# cd pl_examples/basic_examples
# bash submit_ddp_job.sh
# bash submit_ddp2_job.sh
docs/source/common/lightning_cli.rst (2 changes: 1 addition & 1 deletion)
@@ -224,7 +224,7 @@ datamodule class. However, there are many cases in which the objective is to eas
multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is
specified by an import path and init arguments. For example, with a tool implemented as:

-.. testcode::
+.. code-block:: python
from pytorch_lightning.utilities.cli import LightningCLI
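For context beyond the rendered hunk, a minimal sketch of such a tool. `MyModelBaseClass` and `MyDataModuleBaseClass` are placeholder names (not from this commit); the `subclass_mode_*` flags are the real `LightningCLI` parameters that let a model or datamodule be selected by import path and configured with its own init arguments:

```python
from pytorch_lightning.utilities.cli import LightningCLI

from mycode.models import MyModelBaseClass  # placeholder import
from mycode.data import MyDataModuleBaseClass  # placeholder import

# Accept any subclass of the given base classes, chosen on the command
# line by class path and configured with its own init arguments.
cli = LightningCLI(
    MyModelBaseClass,
    MyDataModuleBaseClass,
    subclass_mode_model=True,
    subclass_mode_data=True,
)
```

A run could then pick the concrete classes on the command line, e.g. `--model=mycode.models.MyCoolModel --model.my_init_arg=32`.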
pl_examples/basic_examples/README.md (12 changes: 6 additions & 6 deletions)
@@ -8,10 +8,10 @@ Trains MNIST where the model is defined inside the `LightningModule`.
python simple_image_classifier.py

# gpus (any number)
-python simple_image_classifier.py --gpus 2
+python simple_image_classifier.py --trainer.gpus 2

# dataparallel
-python simple_image_classifier.py --gpus 2 --distributed_backend 'dp'
+python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
```

---
@@ -30,10 +30,10 @@ Generic image classifier with an arbitrary backbone (ie: a simple system)
python backbone_image_classifier.py

# gpus (any number)
-python backbone_image_classifier.py --gpus 2
+python backbone_image_classifier.py --trainer.gpus 2

# dataparallel
-python backbone_image_classifier.py --gpus 2 --distributed_backend 'dp'
+python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
```

---
@@ -44,10 +44,10 @@ Showing the power of a system... arbitrarily complex training loops
python autoencoder.py

# gpus (any number)
-python autoencoder.py --gpus 2
+python autoencoder.py --trainer.gpus 2

# dataparallel
-python autoencoder.py --gpus 2 --distributed_backend 'dp'
+python autoencoder.py --trainer.gpus 2 --trainer.accelerator 'dp'
```
---
# Multi-node example
pl_examples/basic_examples/autoencoder.py (68 changes: 30 additions & 38 deletions)
@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST autoencoder example.
from argparse import ArgumentParser
To run:
python autoencoder.py --trainer.max_epochs=50
"""

import torch
import torch.nn.functional as F
@@ -21,6 +25,7 @@

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
+from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
@@ -87,44 +92,31 @@ def configure_optimizers(self):
        return optimizer


+class MyDataModule(pl.LightningDataModule):
+
+    def __init__(
+        self,
+        batch_size: int = 32,
+    ):
+        super().__init__()
+        dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
+        self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
+        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
+        self.batch_size = batch_size
+
+    def train_dataloader(self):
+        return DataLoader(self.mnist_train, batch_size=self.batch_size)
+
+    def val_dataloader(self):
+        return DataLoader(self.mnist_val, batch_size=self.batch_size)
+
+    def test_dataloader(self):
+        return DataLoader(self.mnist_test, batch_size=self.batch_size)


def cli_main():
-    pl.seed_everything(1234)
-
-    # ------------
-    # args
-    # ------------
-    parser = ArgumentParser()
-    parser.add_argument('--batch_size', default=32, type=int)
-    parser.add_argument('--hidden_dim', type=int, default=64)
-    parser = pl.Trainer.add_argparse_args(parser)
-    args = parser.parse_args()
-
-    # ------------
-    # data
-    # ------------
-    dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
-    mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
-    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
-
-    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
-    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
-    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
-
-    # ------------
-    # model
-    # ------------
-    model = LitAutoEncoder(args.hidden_dim)
-
-    # ------------
-    # training
-    # ------------
-    trainer = pl.Trainer.from_argparse_args(args)
-    trainer.fit(model, train_loader, val_loader)
-
-    # ------------
-    # testing
-    # ------------
-    result = trainer.test(test_dataloaders=test_loader)
+    cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
+    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
    print(result)
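To make the before/after explicit: the single `LightningCLI` call added above takes over the parsing, seeding, instantiation, and fitting that the deleted block wired up by hand. A rough sketch of the manual equivalent, simplified and not the library's actual internals:

```python
import pytorch_lightning as pl


def manual_equivalent(trainer_kwargs: dict, model_kwargs: dict, data_kwargs: dict):
    """Approximates what LightningCLI(LitAutoEncoder, MyDataModule, ...) automates."""
    pl.seed_everything(1234)  # from seed_everything_default=1234
    datamodule = MyDataModule(**data_kwargs)  # parsed from the data.* args
    model = LitAutoEncoder(**model_kwargs)  # parsed from the model.* args
    trainer = pl.Trainer(**trainer_kwargs)  # parsed from the trainer.* args
    trainer.fit(model, datamodule=datamodule)  # the CLI runs fit on instantiation
    return trainer
```

This grouping is also why the flags elsewhere in this commit became namespaced (`--trainer.max_epochs`, `--data.batch_size`): each group is routed to the matching component's constructor.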


pl_examples/basic_examples/backbone_image_classifier.py (89 changes: 44 additions & 45 deletions)
@@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST backbone image classifier example.
from argparse import ArgumentParser
To run:
python backbone_image_classifier.py --trainer.max_epochs=50
"""

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
+from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
@@ -59,7 +64,11 @@ class LitClassifier(pl.LightningModule):
    )
    """

-    def __init__(self, backbone, learning_rate=1e-3):
+    def __init__(
+        self,
+        backbone,
+        learning_rate: float = 0.0001,
+    ):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = backbone
@@ -92,52 +101,42 @@ def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("LitClassifier")
-        parser.add_argument('--learning_rate', type=float, default=0.0001)
-        return parent_parser

+class MyDataModule(pl.LightningDataModule):
+
+    def __init__(
+        self,
+        batch_size: int = 32,
+    ):
+        super().__init__()
+        dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
+        self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
+        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
+        self.batch_size = batch_size
+
+    def train_dataloader(self):
+        return DataLoader(self.mnist_train, batch_size=self.batch_size)
+
+    def val_dataloader(self):
+        return DataLoader(self.mnist_val, batch_size=self.batch_size)
+
+    def test_dataloader(self):
+        return DataLoader(self.mnist_test, batch_size=self.batch_size)


+class MyLightningCLI(LightningCLI):
+
+    def add_arguments_to_parser(self, parser):
+        parser.add_class_arguments(Backbone, 'model.backbone')
+
+    def instantiate_model(self):
+        self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone'])
+        super().instantiate_model()


def cli_main():
-    pl.seed_everything(1234)
-
-    # ------------
-    # args
-    # ------------
-    parser = ArgumentParser()
-    parser.add_argument('--batch_size', default=32, type=int)
-    parser.add_argument('--hidden_dim', type=int, default=128)
-    parser = pl.Trainer.add_argparse_args(parser)
-    parser = LitClassifier.add_model_specific_args(parser)
-    args = parser.parse_args()
-
-    # ------------
-    # data
-    # ------------
-    dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
-    mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
-    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
-
-    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
-    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
-    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
-
-    # ------------
-    # model
-    # ------------
-    model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)
-
-    # ------------
-    # training
-    # ------------
-    trainer = pl.Trainer.from_argparse_args(args)
-    trainer.fit(model, train_loader, val_loader)
-
-    # ------------
-    # testing
-    # ------------
-    result = trainer.test(test_dataloaders=test_loader)
+    cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
+    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
    print(result)
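Because `add_arguments_to_parser` registers the `Backbone` init arguments under `model.backbone`, the whole classifier is now configurable from the command line. A hedged smoke-test sketch (not part of the commit): `LightningCLI` parses `sys.argv`, so seeding it before calling `cli_main()` exercises the nested argument groups.

```python
import sys

# Hypothetical quick check; the flag values are arbitrary. Each flag maps onto
# a constructor argument: trainer.* -> pl.Trainer, model.* -> LitClassifier,
# model.backbone.* -> Backbone, data.* -> MyDataModule.
sys.argv = [
    "backbone_image_classifier.py",
    "--trainer.max_epochs=1",
    "--model.learning_rate=0.001",
    "--model.backbone.hidden_dim=128",
    "--data.batch_size=32",
]
cli_main()
```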


