-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
autoencoder.py
195 lines (160 loc) · 7.03 KB
/
autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST autoencoder example.
To run: python autoencoder.py --trainer.max_epochs=50
"""
from os import path
from typing import Optional
import torch
import torch.nn.functional as F
from lightning.pytorch import LightningDataModule, LightningModule, Trainer, callbacks, cli_lightning_logo
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.mnist_datamodule import MNIST
from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from torch import nn
from torch.utils.data import DataLoader, random_split
# torchvision is optional: these names are only bound when it is installed. Code that uses
# them (``torchvision``, ``transforms``, ``save_image``) checks ``_TORCHVISION_AVAILABLE`` first
# or will raise ``NameError`` without torchvision.
if _TORCHVISION_AVAILABLE:
    import torchvision
    from torchvision import transforms
    from torchvision.utils import save_image

# Dataset root, resolved relative to this script's directory: ``<script_dir>/../../Datasets``.
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
class ImageSampler(callbacks.Callback):
    """Callback that saves original and reconstructed validation images as PNG grids.

    At the end of every training epoch (on global rank 0 only, via ``rank_zero_only``) it takes
    the first ``num_samples`` images of ``trainer.datamodule.mnist_val``, feeds them flattened
    through the model and writes ``grid_generated_{epoch}.png`` to the current working directory.
    On epoch ``0`` it additionally writes the untouched inputs to ``grid_ori_0.png``.
    """

    def __init__(
        self,
        num_samples: int = 3,
        nrow: int = 8,
        padding: int = 2,
        normalize: bool = True,
        value_range: Optional[tuple[int, int]] = None,
        scale_each: bool = False,
        pad_value: int = 0,
    ) -> None:
        """
        Args:
            num_samples: Number of images displayed in the grid. Default: ``3``.
            nrow: Number of images displayed in each row of the grid.
                The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
            padding: Amount of padding. Default: ``2``.
            normalize: If ``True``, shift the image to the range (0, 1),
                by the min and max values specified by ``value_range``. Default: ``True``.
            value_range: Tuple (min, max) where min and max are numbers,
                then these numbers are used to normalize the image. By default, min and max
                are computed from the tensor.
            scale_each: If ``True``, scale each image in the batch of
                images separately rather than the (min, max) over all images. Default: ``False``.
            pad_value: Value for the padded pixels. Default: ``0``.

        Raises:
            ModuleNotFoundError: If ``torchvision`` is not installed.
        """
        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
        super().__init__()
        self.num_samples = num_samples
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.value_range = value_range
        self.scale_each = scale_each
        self.pad_value = pad_value

    def _to_grid(self, images):
        """Arrange a batch of images into one grid tensor using the configured layout options."""
        return torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            value_range=self.value_range,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
        )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Reconstruct a few validation images and save input/output grids as PNG files."""
        if not _TORCHVISION_AVAILABLE:
            return

        # Assumes the attached datamodule exposes a ``mnist_val`` dataset (as ``MyDataModule`` does).
        # Without shuffling, this is always the same first ``num_samples`` validation images.
        images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
        images_flattened = images.view(images.size(0), -1)

        # Generate in eval mode without autograd, and restore the module's previous mode even if
        # the forward pass raises, so training is never silently continued in eval mode.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                images_generated = pl_module(images_flattened.to(pl_module.device))
        finally:
            pl_module.train(was_training)

        epoch = trainer.current_epoch
        if epoch == 0:
            save_image(self._to_grid(images), f"grid_ori_{epoch}.png")
        # The model outputs flat vectors; reshape them back to the input image layout for plotting.
        reconstructions = images_generated.reshape(images.shape)
        save_image(self._to_grid(reconstructions), f"grid_generated_{epoch}.png")
class LitAutoEncoder(LightningModule):
    """Fully connected autoencoder for flattened 28x28 (MNIST-sized) images.

    The encoder maps ``28 * 28`` input features through one ReLU hidden layer of width
    ``hidden_dim`` down to a ``latent_dim``-dimensional code; the decoder mirrors it back to
    ``28 * 28`` outputs. Train/val/test steps log the mean squared reconstruction error.

    >>> LitAutoEncoder()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    LitAutoEncoder(
      (encoder): ...
      (decoder): ...
    )
    """

    def __init__(self, hidden_dim: int = 64, learning_rate: float = 10e-3, latent_dim: int = 3):
        """
        Args:
            hidden_dim: Width of the hidden layer in both encoder and decoder. Default: ``64``.
            learning_rate: Adam learning rate (note ``10e-3 == 1e-2``). Default: ``10e-3``.
            latent_dim: Size of the bottleneck code. Default: ``3``.
        """
        super().__init__()
        # Records the init arguments in ``self.hparams``; ``configure_optimizers`` reads the lr from there.
        self.save_hyperparameters()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))

    def forward(self, x):
        """Encode then decode ``x``; the last dimension of ``x`` must be ``28 * 28``."""
        z = self.encoder(x)
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss so Lightning can backpropagate it."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` for the batch."""
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` for the batch."""
        self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the flattened reconstructions for the batch."""
        x = self._prepare_batch(batch)
        return self(x)

    def configure_optimizers(self):
        """Adam over all parameters with the saved ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def _prepare_batch(self, batch):
        """Drop the labels of an ``(inputs, labels)`` batch and flatten each input to a vector."""
        x, _ = batch
        return x.view(x.size(0), -1)

    def _common_step(self, batch, batch_idx, stage: str):
        """Compute the reconstruction MSE, log it as ``{stage}_loss`` and return it."""
        x = self._prepare_batch(batch)
        x_hat = self(x)
        # ``F.mse_loss(input, target)``: prediction first, ground truth second.
        loss = F.mse_loss(x_hat, x)
        self.log(f"{stage}_loss", loss, on_step=True)
        return loss
class MyDataModule(LightningDataModule):
    """MNIST datamodule: a 55000/5000 train/val split of the training set plus the test set.

    The datasets are created with ``download=True`` under ``DATASETS_PATH`` when the module is
    constructed, so ``mnist_train``/``mnist_val``/``mnist_test`` exist right after ``__init__``.
    """

    def __init__(self, batch_size: int = 32):
        """
        Args:
            batch_size: Batch size used by every dataloader. Default: ``32``.
        """
        super().__init__()
        # ``ToTensor`` holds no state, so one instance can serve both datasets.
        to_tensor = transforms.ToTensor()
        dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=to_tensor)
        self.mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=to_tensor)
        # The split lengths must add up to len(dataset) (assumed to be 60000). The dedicated,
        # fixed-seed generator keeps the validation subset identical across runs.
        self.mnist_train, self.mnist_val = random_split(
            dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
        self.batch_size = batch_size

    def train_dataloader(self):
        """Shuffled training batches."""
        # ``random_split`` only permutes once; without ``shuffle=True`` every epoch would
        # visit the training samples in the same fixed order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Validation batches in a fixed order."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Test batches in a fixed order."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Prediction runs over the test set, in a fixed order."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
def cli_main() -> None:
    """Build the model/datamodule from CLI arguments, then fit, test and predict.

    Test and predict both load ``ckpt_path="best"``; the first prediction batch is printed.
    """
    cli = LightningCLI(
        LitAutoEncoder,
        MyDataModule,
        seed_everything_default=1234,
        run=False,  # used to de-activate automatic fitting.
        trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
        save_config_kwargs={"overwrite": True},
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # NOTE(review): ``ckpt_path="best"`` relies on the trainer having saved a checkpoint during
    # ``fit`` — confirm this still holds if checkpointing is disabled from the CLI.
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
    # ``predict`` can return ``None`` (predictions not returned) or an empty list (no predict
    # batches); only print the first batch when there is one instead of crashing.
    if predictions:
        print(predictions[0])


if __name__ == "__main__":
    # Script entry point: show the Lightning CLI logo, then run the full pipeline.
    cli_lightning_logo()
    cli_main()