-
Notifications
You must be signed in to change notification settings - Fork 631
/
main_train_gan.py
254 lines (209 loc) · 9.77 KB
/
main_train_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import os.path
import math
import argparse
import random
import numpy as np
import logging
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch
from utils import utils_logger
from utils import utils_image as util
from utils import utils_option as option
from utils.utils_dist import get_dist_info, init_dist
from data.select_dataset import define_Dataset
from models.select_model import define_Model
'''
# --------------------------------------------
# training code for GAN-based model
# --------------------------------------------
# Kai Zhang (cskaizhang@gmail.com)
# github: https://github.com/cszn/KAIR
# --------------------------------------------
'''
def _str2bool(value):
    """Convert a command-line flag value to a bool.

    argparse passes option values as strings and every non-empty string is
    truthy, so without a converter ``--dist False`` turns the flag ON.
    Only false-like spellings map to ``False``; any other value keeps its
    previous meaning (enabled), which stays backward-compatible with
    ``--dist True``.
    """
    if isinstance(value, bool):
        return value
    false_spellings = ('', '0', 'false', 'f', 'no', 'n', 'off', 'none')
    return str(value).strip().lower() not in false_spellings


def main(json_path='options/train_msrresnet_gan.json'):
    """Run GAN training as configured by an option JSON file.

    Parses the command line, loads the options, optionally initializes
    distributed training, points the options at the latest checkpoints found
    in ``opt['path']['models']``, builds the data loaders and the model, and
    then iterates over the training data for up to 1,000,000 epochs
    (effectively until interrupted). Logging, checkpoint saving and testing
    are performed on rank 0 only.

    Args:
        json_path: Default value of the ``--opt`` command-line argument.

    Raises:
        NotImplementedError: If ``opt['datasets']`` contains a phase other
            than ``'train'`` or ``'test'``.
        ValueError: If ``opt['datasets']`` has no ``'train'`` phase.
    """
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    # NOTE(review): --launcher and --local_rank are accepted but never read in
    # this function; init_dist below is always called with 'pytorch'.
    parser.add_argument('--launcher', default='pytorch', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--dist', type=_str2bool, default=False)
    args = parser.parse_args()  # parse once and reuse

    opt = option.parse(args.opt, is_train=True)
    opt['dist'] = args.dist

    # ----------------------------------------
    # distributed settings
    # ----------------------------------------
    if opt['dist']:
        init_dist('pytorch')
    opt['rank'], opt['world_size'] = get_dist_info()

    if opt['rank'] == 0:
        util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

    # ----------------------------------------
    # update opt
    # ----------------------------------------
    # Use the latest checkpoint of every network / optimizer found in the
    # models directory, and resume the step counter from the largest
    # iteration among them.
    models_dir = opt['path']['models']
    init_iter_G, init_path_G = option.find_last_checkpoint(models_dir, net_type='G')
    init_iter_D, init_path_D = option.find_last_checkpoint(models_dir, net_type='D')
    init_iter_E, init_path_E = option.find_last_checkpoint(models_dir, net_type='E')
    opt['path']['pretrained_netG'] = init_path_G
    opt['path']['pretrained_netD'] = init_path_D
    opt['path']['pretrained_netE'] = init_path_E
    init_iter_optimizerG, init_path_optimizerG = option.find_last_checkpoint(models_dir, net_type='optimizerG')
    init_iter_optimizerD, init_path_optimizerD = option.find_last_checkpoint(models_dir, net_type='optimizerD')
    opt['path']['pretrained_optimizerG'] = init_path_optimizerG
    opt['path']['pretrained_optimizerD'] = init_path_optimizerD
    current_step = max(init_iter_G, init_iter_D, init_iter_E, init_iter_optimizerG, init_iter_optimizerD)

    # The scale factor is passed as the `border` argument of the PSNR
    # computation during testing.
    border = opt['scale']

    # ----------------------------------------
    # save opt to a '../option.json' file
    # ----------------------------------------
    if opt['rank'] == 0:
        option.save(opt)

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    # `logger` exists on rank 0 only; every use below is guarded by a rank check.
    if opt['rank'] == 0:
        logger_name = 'train'
        utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log'))
        logger = logging.getLogger(logger_name)
        logger.info(option.dict2str(opt))

    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        # NOTE(review): with --dist every process draws its own random seed
        # here, but DistributedSampler expects the same seed on all
        # processes -- set 'manual_seed' for distributed runs; confirm.
        seed = random.randint(1, 10000)
    print('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ----------------------------------------
    # Step--2 (create dataloader)
    # ----------------------------------------
    # 1) create_dataset
    # 2) create_dataloader for train and test
    # ----------------------------------------
    # Pre-bind the loaders so that a missing phase is detected explicitly
    # instead of surfacing later as an UnboundLocalError.
    train_loader = None
    train_sampler = None
    test_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = define_Dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size']))
            if opt['rank'] == 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            if opt['dist']:
                # The configured batch size and worker count are divided by
                # opt['num_gpu'] for each process; shuffling is delegated to
                # the sampler, so the loader itself must not shuffle.
                train_sampler = DistributedSampler(train_set,
                                                   shuffle=dataset_opt['dataloader_shuffle'],
                                                   drop_last=True,
                                                   seed=seed)
                train_loader = DataLoader(train_set,
                                          batch_size=dataset_opt['dataloader_batch_size']//opt['num_gpu'],
                                          shuffle=False,
                                          num_workers=dataset_opt['dataloader_num_workers']//opt['num_gpu'],
                                          drop_last=True,
                                          pin_memory=True,
                                          sampler=train_sampler)
            else:
                train_loader = DataLoader(train_set,
                                          batch_size=dataset_opt['dataloader_batch_size'],
                                          shuffle=dataset_opt['dataloader_shuffle'],
                                          num_workers=dataset_opt['dataloader_num_workers'],
                                          drop_last=True,
                                          pin_memory=True)
        elif phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set, batch_size=1,
                                     shuffle=False, num_workers=1,
                                     drop_last=False, pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    if train_loader is None:
        raise ValueError("No 'train' phase found in opt['datasets']; nothing to train on.")
    if test_loader is None and opt['rank'] == 0:
        logger.warning("No 'test' phase found in opt['datasets']; periodic testing is disabled.")

    # ----------------------------------------
    # Step--3 (initialize model)
    # ----------------------------------------
    model = define_Model(opt)
    model.init_train()
    if opt['rank'] == 0:
        logger.info(model.info_network())
        logger.info(model.info_params())

    # ----------------------------------------
    # Step--4 (main training)
    # ----------------------------------------
    for epoch in range(1000000):  # keep running
        if opt['dist']:
            train_sampler.set_epoch(epoch + seed)

        for train_data in train_loader:
            current_step += 1

            # -------------------------------
            # 1) update learning rate
            # -------------------------------
            model.update_learning_rate(current_step)

            # -------------------------------
            # 2) feed patch pairs
            # -------------------------------
            model.feed_data(train_data)

            # -------------------------------
            # 3) optimize parameters
            # -------------------------------
            model.optimize_parameters(current_step)

            # -------------------------------
            # 4) training information
            # -------------------------------
            if current_step % opt['train']['checkpoint_print'] == 0 and opt['rank'] == 0:
                logs = model.current_log()  # such as loss
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(epoch, current_step, model.current_learning_rate())
                # merge log information into message
                message += ''.join('{:s}: {:.3e} '.format(k, v) for k, v in logs.items())
                logger.info(message)

            # -------------------------------
            # 5) save model
            # -------------------------------
            if current_step % opt['train']['checkpoint_save'] == 0 and opt['rank'] == 0:
                logger.info('Saving the model.')
                model.save(current_step)

            # -------------------------------
            # 6) testing
            # -------------------------------
            if (test_loader is not None
                    and current_step % opt['train']['checkpoint_test'] == 0
                    and opt['rank'] == 0):
                avg_psnr = 0.0
                idx = 0

                for test_data in test_loader:
                    idx += 1
                    image_name_ext = os.path.basename(test_data['L_path'][0])
                    img_name, ext = os.path.splitext(image_name_ext)

                    # One output directory per test image, named after it.
                    img_dir = os.path.join(opt['path']['images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(test_data)
                    model.test()

                    visuals = model.current_visuals()
                    E_img = util.tensor2uint(visuals['E'])
                    H_img = util.tensor2uint(visuals['H'])

                    # -----------------------
                    # save estimated image E
                    # -----------------------
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.imsave(E_img, save_img_path)

                    # -----------------------
                    # calculate PSNR
                    # -----------------------
                    current_psnr = util.calculate_psnr(E_img, H_img, border=border)
                    logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr))
                    avg_psnr += current_psnr

                if idx > 0:
                    avg_psnr = avg_psnr / idx
                    # testing log
                    logger.info('<epoch:{:3d}, iter:{:8,d}, Average PSNR : {:<.2f}dB\n'.format(epoch, current_step, avg_psnr))
                else:
                    # An empty test set used to raise ZeroDivisionError here.
                    logger.warning('Test set is empty; skipping the average PSNR report.')
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()