Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed basic_examples to use LightningCLI #6862

Merged
merged 12 commits into from
Apr 15, 2021
Merged
2 changes: 1 addition & 1 deletion docs/source/common/lightning_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ datamodule class. However, there are many cases in which the objective is to eas
multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is
specified by an import path and init arguments. For example, with a tool implemented as:

.. testcode::
.. code-block:: python

from pytorch_lightning.utilities.cli import LightningCLI

Expand Down
78 changes: 39 additions & 39 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST autoencoder example.

from argparse import ArgumentParser
To run:
python autoencoder.py --trainer.max_epochs=50
"""

import torch
import torch.nn.functional as F
Expand All @@ -21,6 +25,7 @@

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
from pytorch_lightning.utilities.cli import LightningCLI

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
Expand Down Expand Up @@ -86,45 +91,40 @@ def configure_optimizers(self):
return optimizer


class MyDataModule(pl.LightningDataModule):
    """DataModule providing MNIST train/val/test dataloaders.

    The full training set is split 55000/5000 into train and validation;
    the official test split is used as-is.
    """

    def __init__(
        self,
        batch_size: int = 32,
    ):
        super().__init__()
        self.batch_size = batch_size
        to_tensor = transforms.ToTensor()
        full_train = MNIST(_DATASETS_PATH, train=True, download=True, transform=to_tensor)
        self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=to_tensor)
        self.mnist_train, self.mnist_val = random_split(full_train, [55000, 5000])

    def train_dataloader(self):
        """Dataloader over the 55000-sample training split."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Dataloader over the 5000-sample validation split."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Dataloader over the official MNIST test split."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


class MyLightningCLI(LightningCLI):
    """LightningCLI that seeds runs by default and tests after fitting."""

    def before_parse_arguments(self, parser):
        # Fixed default seed so example runs are reproducible unless overridden.
        parser.set_defaults(seed_everything=1234)

    def after_fit(self):
        # Evaluate on the datamodule's test split once training finishes.
        test_loader = self.datamodule.test_dataloader()
        print(self.trainer.test(test_dataloaders=test_loader))


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_dim', type=int, default=64)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

# ------------
# model
# ------------
model = LitAutoEncoder(args.hidden_dim)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)

# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
print(result)
MyLightningCLI(LitAutoEncoder, MyDataModule)


if __name__ == '__main__':
Expand Down
96 changes: 50 additions & 46 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST backbone image classifier example.

from argparse import ArgumentParser
To run:
python backbone_image_classifier.py --trainer.max_epochs=50
"""

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
from pytorch_lightning.utilities.cli import LightningCLI

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
Expand Down Expand Up @@ -58,7 +63,11 @@ class LitClassifier(pl.LightningModule):
)
"""

def __init__(self, backbone, learning_rate=1e-3):
def __init__(
self,
backbone,
learning_rate: float = 0.0001,
):
super().__init__()
self.save_hyperparameters()
self.backbone = backbone
Expand Down Expand Up @@ -91,53 +100,48 @@ def configure_optimizers(self):
# self.hparams available because we called self.save_hyperparameters()
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitClassifier")
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parent_parser

class MyDataModule(pl.LightningDataModule):
    """DataModule exposing MNIST splits as Lightning dataloaders.

    Splits the 60000-image training set into 55000 train / 5000 validation
    samples and keeps the official test set for testing.
    """

    def __init__(
        self,
        batch_size: int = 32,
    ):
        super().__init__()
        self.batch_size = batch_size
        transform = transforms.ToTensor()
        train_full = MNIST(_DATASETS_PATH, train=True, download=True, transform=transform)
        self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(train_full, [55000, 5000])

    def train_dataloader(self):
        """Return the training-split dataloader."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation-split dataloader."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test-split dataloader."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


class MyLightningCLI(LightningCLI):
    """LightningCLI wiring the Backbone submodule into the model config.

    Exposes ``Backbone``'s init arguments under ``model.backbone``,
    seeds runs by default, and evaluates on the test set after fitting.
    """

    def add_arguments_to_parser(self, parser):
        # Make Backbone's constructor arguments configurable from the CLI.
        parser.add_class_arguments(Backbone, 'model.backbone')

    def before_parse_arguments(self, parser):
        # Fixed default seed for reproducible example runs.
        parser.set_defaults(seed_everything=1234)

    def instantiate_model(self):
        # Build the backbone from its parsed config before the model itself
        # is instantiated, so it can be passed in as a constructor argument.
        backbone_cfg = self.config['model']['backbone']
        self.config_init['model']['backbone'] = Backbone(**backbone_cfg)
        super().instantiate_model()

    def after_fit(self):
        # Run a test pass on the datamodule's test split and report results.
        print(self.trainer.test(test_dataloaders=self.datamodule.test_dataloader()))


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_dim', type=int, default=128)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

# ------------
# model
# ------------
model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)

# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
print(result)
MyLightningCLI(LitClassifier, MyDataModule)


if __name__ == '__main__':
Expand Down
62 changes: 22 additions & 40 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST simple image classifier example.

To run:
python simple_image_classifier.py --trainer.max_epochs=50
"""

from argparse import ArgumentParser
from pprint import pprint

import torch
Expand All @@ -21,6 +26,7 @@
import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
from pytorch_lightning.utilities.cli import LightningCLI


class LitClassifier(pl.LightningModule):
Expand All @@ -32,7 +38,11 @@ class LitClassifier(pl.LightningModule):
)
"""

def __init__(self, hidden_dim=128, learning_rate=1e-3):
def __init__(
self,
hidden_dim: int = 128,
learning_rate: float = 0.0001,
):
super().__init__()
self.save_hyperparameters()

Expand Down Expand Up @@ -66,47 +76,19 @@ def test_step(self, batch, batch_idx):
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitClassifier")
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parent_parser

class MyLightningCLI(LightningCLI):
    """LightningCLI that seeds runs by default and tests after fitting."""

    def before_parse_arguments(self, parser):
        # Fixed default seed so example runs are reproducible unless overridden.
        parser.set_defaults(seed_everything=1234)

    def after_fit(self):
        # Evaluate the fitted model on the datamodule's test split.
        outcome = self.trainer.test(self.model, datamodule=self.datamodule)
        pprint(outcome)


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
parser = MNISTDataModule.add_argparse_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dm = MNISTDataModule.from_argparse_args(args)

# ------------
# model
# ------------
model = LitClassifier(args.hidden_dim, args.learning_rate)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, datamodule=dm)

# ------------
# testing
# ------------
result = trainer.test(model, datamodule=dm)
pprint(result)
MyLightningCLI(LitClassifier, MNISTDataModule)


if __name__ == '__main__':
Expand Down
18 changes: 9 additions & 9 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@
from pl_examples import _DALI_AVAILABLE

ARGS_DEFAULT = """
--default_root_dir %(tmpdir)s \
--max_epochs 1 \
--batch_size 32 \
--limit_train_batches 2 \
--limit_val_batches 2 \
--trainer.default_root_dir %(tmpdir)s \
--trainer.max_epochs 1 \
--trainer.limit_train_batches 2 \
--trainer.limit_val_batches 2 \
--data.batch_size 32 \
"""

ARGS_GPU = ARGS_DEFAULT + """
--gpus 1 \
--trainer.gpus 1 \
"""

ARGS_DP = ARGS_DEFAULT + """
--gpus 2 \
--accelerator dp \
--trainer.gpus 2 \
--trainer.accelerator dp \
"""

ARGS_AMP = """
--precision 16 \
--trainer.precision 16 \
"""


Expand Down
Loading