-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
srgan.py
176 lines (145 loc) · 6.15 KB
/
srgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import auto_fp16
from ..builder import build_backbone, build_component, build_loss
from ..common import set_requires_grad
from ..registry import MODELS
from .basic_restorer import BasicRestorer
@MODELS.register_module()
class SRGAN(BasicRestorer):
    """SRGAN model for single image super-resolution.

    Ref:
    Photo-Realistic Single Image Super-Resolution Using a Generative
    Adversarial Network.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
        gan_loss (dict): Config for the gan loss.
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        perceptual_loss (dict): Config for the perceptual loss. Default: None.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 perceptual_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # NOTE(review): `super(BasicRestorer, self)` skips
        # `BasicRestorer.__init__` and calls its parent's __init__ directly —
        # presumably to avoid BasicRestorer building its own
        # generator/losses; confirm against BasicRestorer's implementation.
        super(BasicRestorer, self).__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # generator
        self.generator = build_backbone(generator)
        # discriminator (optional; only needed for adversarial training)
        self.discriminator = build_component(
            discriminator) if discriminator else None

        # support fp16 (flag consumed by the `auto_fp16` decorator machinery)
        self.fp16_enabled = False

        # losses: each is optional and skipped in `train_step` when None
        self.gan_loss = build_loss(gan_loss) if gan_loss else None
        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
        self.perceptual_loss = build_loss(
            perceptual_loss) if perceptual_loss else None

        # disc_steps: update the generator only every `disc_steps` iterations
        self.disc_steps = 1 if self.train_cfg is None else self.train_cfg.get(
            'disc_steps', 1)
        # disc_init_steps: warm-up iterations during which only the
        # discriminator is updated
        self.disc_init_steps = (0 if self.train_cfg is None else
                                self.train_cfg.get('disc_init_steps', 0))
        self.step_counter = 0  # counting training steps

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Delegates to the generator's (and, if present, the discriminator's)
        own ``init_weights``.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained=pretrained)
        if self.discriminator:
            self.discriminator.init_weights(pretrained=pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Only test-mode forwarding is supported here; training is driven
        entirely through :meth:`train_step`, so a train-mode call raises.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments.

        Raises:
            ValueError: If called with ``test_mode=False``.
        """
        if test_mode:
            return self.forward_test(lq, gt, **kwargs)

        raise ValueError(
            'SRGAN model does not support `forward_train` function.')

    def train_step(self, data_batch, optimizer):
        """Train step.

        One iteration of GAN training: an (optionally skipped) generator
        update followed by a discriminator update on real and fake samples.

        Args:
            data_batch (dict): A batch of data. Must contain 'lq' and 'gt'.
            optimizer (obj): Optimizer. Dict-like with 'generator' and
                'discriminator' entries.

        Returns:
            dict: Returned output with keys ``log_vars``, ``num_samples``
                and ``results``.
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters while optimizing the
        # generator (its gradients are frozen, but it still runs forward
        # to produce the generator's adversarial loss).
        set_requires_grad(self.discriminator, False)
        # Generator is updated only every `disc_steps` iterations, and not
        # at all during the first `disc_init_steps` warm-up iterations.
        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
            if self.perceptual_loss:
                # perceptual_loss returns a (perceptual, style) pair; either
                # element may be None depending on its config.
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style
            # gan loss for generator (wants the discriminator to judge the
            # fake output as real, hence target_is_real=True, is_disc=False)
            fake_g_pred = self.discriminator(fake_g_output)
            losses['loss_gan'] = self.gan_loss(
                fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator: re-enable its gradients for its own update
        set_requires_grad(self.discriminator, True)
        # real branch
        real_d_pred = self.discriminator(gt)
        loss_d_real = self.gan_loss(
            real_d_pred, target_is_real=True, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
        optimizer['discriminator'].zero_grad()
        # real and fake losses are backpropagated separately; their
        # gradients accumulate before the single optimizer step below.
        loss_d.backward()
        log_vars.update(log_vars_d)
        # fake branch (detached so no gradients flow back into the generator)
        fake_d_pred = self.discriminator(fake_g_output.detach())
        loss_d_fake = self.gan_loss(
            fake_d_pred, target_is_real=False, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
        loss_d.backward()
        log_vars.update(log_vars_d)

        optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs