Skip to content

Commit

Permalink
yapf examples (#5709)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jan 30, 2021
1 parent 07f24d2 commit 21d313e
Show file tree
Hide file tree
Showing 17 changed files with 243 additions and 212 deletions.
3 changes: 0 additions & 3 deletions .yapfignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
.git/*

# TODO
pl_examples/*

# TODO
pytorch_lightning/*

Expand Down
1 change: 0 additions & 1 deletion pl_examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_DALI_AVAILABLE = _module_available("nvidia.dali")


LIGHTNING_LOGO = """
####
###########
Expand Down
4 changes: 2 additions & 2 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def __init__(self):
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 64),
nn.ReLU(),
nn.Linear(64, 3)
nn.Linear(64, 3),
)
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 28 * 28)
nn.Linear(64, 28 * 28),
)

def forward(self, x):
Expand Down
2 changes: 2 additions & 0 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Backbone(torch.nn.Module):
(l2): Linear(...)
)
"""

def __init__(self, hidden_dim=128):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
Expand All @@ -55,6 +56,7 @@ class LitClassifier(pl.LightningModule):
(backbone): ...
)
"""

def __init__(self, backbone, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
10 changes: 6 additions & 4 deletions pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@
import pl_bolts
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


#####################
# Modules #
#####################


class Flatten(nn.Module):

def forward(self, x):
return x.view(x.size(0), -1)


###############################
# LightningModule #
###############################
Expand All @@ -61,6 +62,7 @@ class LitResnet(pl.LightningModule):
(sequential_module): Sequential(...)
)
"""

def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
super().__init__()

Expand Down Expand Up @@ -90,9 +92,7 @@ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),

Flatten(),

nn.Dropout(p=0.1),
nn.Linear(4096, 1024),
nn.ReLU(inplace=False),
Expand Down Expand Up @@ -159,7 +159,8 @@ def configure_optimizers(self):
optimizer,
0.1,
epochs=self.trainer.max_epochs,
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)
),
'interval': 'step',
}
}
Expand All @@ -173,6 +174,7 @@ def automatic_optimization(self) -> bool:
# Instantiate Data Module #
#################################


def instantiate_datamodule(args):
train_transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(32, padding=4),
Expand Down
33 changes: 21 additions & 12 deletions pl_examples/basic_examples/dali_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,30 @@ class DALIClassificationLoader(DALIClassificationIterator):
"""

def __init__(
self,
pipelines,
size=-1,
reader_name=None,
auto_reset=False,
fill_last_batch=True,
dynamic_shape=False,
last_batch_padded=False,
self,
pipelines,
size=-1,
reader_name=None,
auto_reset=False,
fill_last_batch=True,
dynamic_shape=False,
last_batch_padded=False,
):
if NEW_DALI_API:
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
super().__init__(
pipelines,
size,
reader_name,
auto_reset,
dynamic_shape,
last_batch_policy=last_batch_policy,
last_batch_padded=last_batch_padded
)
else:
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
dynamic_shape, last_batch_padded)
super().__init__(
pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded
)
self._fill_last_batch = fill_last_batch

def __len__(self):
Expand All @@ -120,6 +128,7 @@ def __len__(self):


class LitClassifier(pl.LightningModule):

def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
12 changes: 7 additions & 5 deletions pl_examples/basic_examples/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def __init__(
super().__init__(*args, **kwargs)
if num_workers and platform.system() == "Windows":
# see: https://stackoverflow.com/a/59680818
warn(f"You have requested num_workers={num_workers} on Windows,"
" but currently recommended is 0, so we set it for you")
warn(
f"You have requested num_workers={num_workers} on Windows,"
" but currently recommended is 0, so we set it for you"
)
num_workers = 0

self.dims = (1, 28, 28)
Expand Down Expand Up @@ -132,9 +134,9 @@ def default_transforms(self):
if not _TORCHVISION_AVAILABLE:
return None
if self.normalize:
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
])
else:
mnist_transforms = transform_lib.ToTensor()

Expand Down
1 change: 1 addition & 0 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LitClassifier(pl.LightningModule):
(l2): Linear(...)
)
"""

def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
3 changes: 3 additions & 0 deletions pl_examples/bug_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class RandomDataset(Dataset):
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
<...bug_report_model.RandomDataset object at ...>
"""

def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
Expand Down Expand Up @@ -124,9 +125,11 @@ def configure_optimizers(self):
# parser = ArgumentParser()
# args = parser.parse_args(opt)


def test_run():

class TestModel(BoringModel):

def on_train_epoch_start(self) -> None:
print('override any method to prove your bug')

Expand Down
51 changes: 21 additions & 30 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -60,14 +59,12 @@

DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"


# --- Finetunning Callback ---


class MilestonesFinetuningCallback(BaseFinetuningCallback):

def __init__(self,
milestones: tuple = (5, 10),
train_bn: bool = True):
def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
self.milestones = milestones
self.train_bn = train_bn

Expand All @@ -78,17 +75,13 @@ def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimi
if epoch == self.milestones[0]:
# unfreeze 5 last layers
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor[-5:],
optimizer=optimizer,
train_bn=self.train_bn
module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
)

elif epoch == self.milestones[1]:
# unfreeze remaing layers
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor[:-5],
optimizer=optimizer,
train_bn=self.train_bn
module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
)


Expand Down Expand Up @@ -149,10 +142,12 @@ def __build_model(self):
self.feature_extractor = nn.Sequential(*_layers)

# 2. Classifier:
_fc_layers = [nn.Linear(2048, 256),
nn.ReLU(),
nn.Linear(256, 32),
nn.Linear(32, 1)]
_fc_layers = [
nn.Linear(2048, 256),
nn.ReLU(),
nn.Linear(256, 32),
nn.Linear(32, 1),
]
self.fc = nn.Sequential(*_fc_layers)

# 3. Loss:
Expand Down Expand Up @@ -218,25 +213,21 @@ def setup(self, stage: str):

train_dataset = ImageFolder(
root=data_path.joinpath("train"),
transform=transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
),
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)

valid_dataset = ImageFolder(
root=data_path.joinpath("validation"),
transform=transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalize,
]
),
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalize,
]),
)

self.train_dataset = train_dataset
Expand Down
33 changes: 16 additions & 17 deletions pl_examples/domain_templates/generative_adversarial_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Generator(nn.Module):
(model): Sequential(...)
)
"""

def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape
Expand All @@ -60,7 +61,7 @@ def block(in_feat, out_feat, normalize=True):
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
nn.Tanh(),
)

def forward(self, z):
Expand All @@ -76,6 +77,7 @@ class Discriminator(nn.Module):
(model): Sequential(...)
)
"""

def __init__(self, img_shape):
super().__init__()

Expand Down Expand Up @@ -106,13 +108,14 @@ class GAN(LightningModule):
)
)
"""

def __init__(
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
):
super().__init__()

Expand All @@ -130,12 +133,9 @@ def __init__(
def add_argparse_args(parent_parser: ArgumentParser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5,
help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999,
help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100,
help="dimensionality of the latent space")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")

return parser

Expand Down Expand Up @@ -180,8 +180,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(imgs)

fake_loss = self.adversarial_loss(
self.discriminator(self(z).detach()), fake)
fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
Expand Down Expand Up @@ -213,14 +212,14 @@ class MNISTDataModule(LightningDataModule):
>>> MNISTDataModule() # doctest: +ELLIPSIS
<...generative_adversarial_net.MNISTDataModule object at ...>
"""

def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
super().__init__()
self.batch_size = batch_size
self.data_path = data_path
self.num_workers = num_workers

self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])])
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
self.dims = (1, 28, 28)

def prepare_data(self, stage=None):
Expand Down
Loading

0 comments on commit 21d313e

Please sign in to comment.