-
Notifications
You must be signed in to change notification settings - Fork 2
/
classifier.py
192 lines (165 loc) · 6.88 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from functools import partial
from typing import Optional, Union

import pytorch_lightning as pl
import timm.optim
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torchvision.ops import StochasticDepth

from extras import RandomCutMixMixUp
from vision_toolbox import backbones
# https://github.com/pytorch/vision/blob/main/references/classification/train.py
# https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/
def image_loader(path, device="cpu"):
    """Load an image file from disk and decode it with ``torchvision.io.decode_jpeg``.

    Args:
        path: Filesystem path of the image file (expected to be JPEG-encoded,
            since only the JPEG decoder is used).
        device: Device forwarded to ``decode_jpeg``; defaults to ``"cpu"``.

    Returns:
        The tensor produced by ``torchvision.io.decode_jpeg``.
    """
    encoded = torchvision.io.read_file(path)
    return torchvision.io.decode_jpeg(encoded, device=device)
class ImageClassifier(pl.LightningModule):
    """Image classifier: a backbone, an optional global pool, and a linear head.

    The training recipe follows the torchvision classification reference linked
    above: optional CutMix/MixUp, label smoothing, separate weight decay for
    normalization and bias parameters, and a cosine learning-rate schedule with
    optional linear warmup.
    """

    def __init__(
        self,
        # model
        backbone: Union[str, backbones.BaseBackbone],
        num_classes: int,
        include_pool: bool = True,
        # augmentation and regularization
        mixup_alpha: float = 0.2,
        cutmix_alpha: float = 1.0,
        # regularization
        weight_decay: float = 2e-5,
        norm_weight_decay: Optional[float] = 0,
        bias_weight_decay: Optional[float] = 0,
        label_smoothing: float = 0.1,
        drop_out: Optional[float] = None,
        drop_path: Optional[float] = None,
        # optimizer and scheduler
        optimizer: str = "SGD",
        momentum: float = 0.9,
        lr: float = 0.05,
        decay_factor: float = 0,
        warmup_epochs: int = 5,
        warmup_factor: float = 0.01,
        # others
        jit: bool = False,
        channels_last: bool = False,
    ):
        """Build the model and record every argument as a hyperparameter.

        Args:
            backbone: Either the name of a factory in ``vision_toolbox.backbones``
                (called with no arguments) or an already-built backbone. It must
                provide ``get_last_out_channels()``, which sizes the linear head.
            num_classes: Number of output logits.
            include_pool: If True, insert ``AdaptiveAvgPool2d((1, 1))`` and
                ``Flatten`` between the backbone and the linear head.
            mixup_alpha: MixUp alpha forwarded to ``RandomCutMixMixUp``.
            cutmix_alpha: CutMix alpha forwarded to ``RandomCutMixMixUp``.
                NOTE: the batch-mixing transform is only built when BOTH alphas
                are > 0; setting either one to 0 disables mixing entirely.
            weight_decay: Weight decay for every parameter not covered by
                ``norm_weight_decay`` / ``bias_weight_decay``.
            norm_weight_decay: Weight decay for normalization-layer parameters;
                ``None`` falls back to ``weight_decay``.
            bias_weight_decay: Weight decay for biases of linear/conv layers;
                ``None`` falls back to ``weight_decay``.
            label_smoothing: Label smoothing for the training loss only (the
                validation loss is computed without smoothing).
            drop_out: If given, overrides ``p`` of every dropout module in the model.
            drop_path: If given, overrides ``p`` of every torchvision
                ``StochasticDepth`` module in the model.
            optimizer: Optimizer class name, looked up in ``torch.optim`` first and
                then in ``timm.optim``. ``SGD`` and ``RMSprop`` also get ``momentum``.
            momentum: Momentum for ``SGD`` / ``RMSprop``; unused otherwise.
            lr: Base learning rate (reached at the end of warmup).
            decay_factor: Final cosine learning rate as a fraction of ``lr``
                (``eta_min = lr * decay_factor``).
            warmup_epochs: Number of linear-warmup epochs; 0 disables warmup.
            warmup_factor: Learning-rate multiplier at the start of warmup.
            jit: If True, compile the model with ``torch.jit.script``.
            channels_last: If True, use ``torch.channels_last`` memory format for
                the model and for every input batch.
        """
        super().__init__()
        self.save_hyperparameters()

        backbone = backbones.__dict__[backbone]() if isinstance(backbone, str) else backbone
        layers = [backbone]
        if include_pool:
            layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            layers.append(nn.Flatten())
        layers.append(nn.Linear(backbone.get_last_out_channels(), num_classes))
        self.model = nn.Sequential(*layers)

        self.mixup_cutmix = (
            RandomCutMixMixUp(num_classes, cutmix_alpha, mixup_alpha) if cutmix_alpha > 0 and mixup_alpha > 0 else None
        )

        # Override regularization strengths baked into the backbone, if requested.
        if drop_out is not None:
            for m in self.model.modules():
                if isinstance(m, nn.modules.dropout._DropoutNd):
                    m.p = drop_out
        if drop_path is not None:
            for m in self.model.modules():
                if isinstance(m, StochasticDepth):
                    m.p = drop_path

        # Memory-format conversion happens before scripting so the scripted
        # module keeps the converted weights.
        if channels_last:
            self.model = self.model.to(memory_format=torch.channels_last)
        if jit:
            self.model = torch.jit.script(self.model)

    def training_step(self, batch, batch_idx):
        """Compute and log the (optionally mixed, label-smoothed) cross-entropy loss."""
        images, labels = batch
        if self.mixup_cutmix is not None:
            images, labels = self.mixup_cutmix(images, labels)
        if self.hparams.channels_last:
            images = images.to(memory_format=torch.channels_last)

        logits = self.model(images)
        loss = F.cross_entropy(logits, labels, label_smoothing=self.hparams.label_smoothing)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation cross-entropy and top-1 accuracy, both synced across devices."""
        images, labels = batch
        if self.hparams.channels_last:
            images = images.to(memory_format=torch.channels_last)

        logits = self.model(images)
        loss = F.cross_entropy(logits, labels)
        self.log("val/loss", loss, sync_dist=True)

        preds = torch.argmax(logits, dim=-1)
        correct = (labels == preds).sum()
        acc = correct / labels.numel()
        # sync_dist so the epoch-level accuracy is reduced across ranks, like val/loss.
        self.log("val/acc", acc, sync_dist=True)

    def _split_decay_params(self):
        """Partition trainable parameters into ``(norm, bias, other)`` lists.

        Adapted from torchvision/ops/_utils.py: leaf normalization modules put all
        their parameters in ``norm``; leaf linear/conv modules put their bias in
        ``bias`` and their weight in ``other``; any remaining leaf, and parameters
        registered directly on non-leaf modules, go to ``other``. Parameters with
        ``requires_grad=False`` are skipped.
        """
        # https://github.com/pytorch/vision/blob/main/torchvision/ops/_utils.py
        norm_classes = (
            nn.modules.batchnorm._BatchNorm,
            nn.modules.instancenorm._InstanceNorm,
            nn.LayerNorm,
            nn.GroupNorm,
        )
        layer_classes = (nn.Linear, nn.modules.conv._ConvNd)

        norm_params = []
        bias_params = []
        other_params = []
        for module in self.modules():
            # Compare against None explicitly: containers such as an empty
            # nn.Sequential/ModuleList are falsy, which would misclassify their
            # parent as a leaf and collect its parameters twice.
            if next(module.children(), None) is not None:
                # Non-leaf: only its own parameters; children are visited separately.
                other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
            elif isinstance(module, norm_classes):
                norm_params.extend(p for p in module.parameters() if p.requires_grad)
            elif isinstance(module, layer_classes):
                if module.weight.requires_grad:
                    other_params.append(module.weight)
                if module.bias is not None and module.bias.requires_grad:
                    bias_params.append(module.bias)
            else:
                other_params.extend(p for p in module.parameters() if p.requires_grad)
        return norm_params, bias_params, other_params

    def configure_optimizers(self):
        """Build the optimizer (with per-group weight decay) and the LR scheduler.

        The scheduler is cosine annealing over ``max_epochs - warmup_epochs``
        epochs, preceded by ``LinearLR`` warmup when ``warmup_epochs > 0``.

        Raises:
            ValueError: if ``optimizer`` names a class found in neither
                ``torch.optim`` nor ``timm.optim``.
        """
        norm_params, bias_params, other_params = self._split_decay_params()

        wd = self.hparams.weight_decay
        norm_wd = self.hparams.norm_weight_decay
        bias_wd = self.hparams.bias_weight_decay
        parameters = [
            {
                "params": norm_params,
                "weight_decay": norm_wd if norm_wd is not None else wd,
            },
            {
                "params": bias_params,
                "weight_decay": bias_wd if bias_wd is not None else wd,
            },
            {"params": other_params, "weight_decay": wd},
        ]
        parameters = [x for x in parameters if x["params"]]  # remove empty params groups

        # build optimizer
        optimizer_name = self.hparams.optimizer
        lr = self.hparams.lr
        momentum = self.hparams.momentum
        if optimizer_name in ("SGD", "RMSprop"):
            optimizer_cls = partial(getattr(torch.optim, optimizer_name), momentum=momentum)
        elif hasattr(torch.optim, optimizer_name):
            optimizer_cls = getattr(torch.optim, optimizer_name)
        elif hasattr(timm.optim, optimizer_name):
            optimizer_cls = getattr(timm.optim, optimizer_name)
        else:
            raise ValueError(f"{optimizer_name} optimizer is not supported")
        # Per-group "weight_decay" entries above override this default.
        optimizer = optimizer_cls(parameters, lr=lr, weight_decay=wd)

        # build scheduler
        warmup_epochs = self.hparams.warmup_epochs
        warmup_factor = self.hparams.warmup_factor
        decay_factor = self.hparams.decay_factor
        lr_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs - warmup_epochs,
            eta_min=lr * decay_factor,
        )
        if warmup_epochs > 0:
            warmup_scheduler = LinearLR(optimizer, start_factor=warmup_factor, total_iters=warmup_epochs)
            lr_scheduler = SequentialLR(
                optimizer,
                schedulers=[warmup_scheduler, lr_scheduler],
                milestones=[warmup_epochs],
            )

            # https://github.com/pytorch/pytorch/issues/67318
            if not hasattr(lr_scheduler, "optimizer"):
                setattr(lr_scheduler, "optimizer", optimizer)

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}