forked from boschresearch/OASIS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
92 lines (78 loc) · 3.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import models.losses as losses
import models.models as models
import dataloaders.dataloaders as dataloaders
import utils.utils as utils
from utils.fid_scores import fid_pytorch
import config
import time
torch.backends.cudnn.benchmark = True
#--- read options ---#
opt = config.read_arguments(train=True)
#--- create utils ---#
timer = utils.timer(opt)
visualizer_losses = utils.losses_saver(opt)
losses_computer = losses.losses_computer(opt)
dataloader, dataloader_val = dataloaders.get_dataloaders(opt)
im_saver = utils.image_saver(opt)
fid_computer = fid_pytorch(opt, dataloader_val)
#--- create models ---#
model = models.OASIS_model(opt)
model = models.put_on_multi_gpus(model, opt)
#--- create optimizers ---#
optimizerG = torch.optim.Adam(model.module.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, opt.beta2))
optimizerD = torch.optim.Adam(model.module.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, opt.beta2))
#--- the training loop ---#
already_started = False
start_epoch, start_iter = utils.get_start_iters(opt.loaded_latest_iter, len(dataloader))
for epoch in range(start_epoch, opt.num_epochs):
for i, data_i in enumerate(dataloader):
if not already_started and i < start_iter:
continue
already_started = True
cur_iter = epoch*len(dataloader) + i
#ta=time.time()
image, label = models.preprocess_input(opt, data_i)
#tb=time.time()
#print ("Preprocess took {}".format(tb-ta) )
#ta=time.time()
#--- generator update ---#
model.module.netG.zero_grad()
loss_G, losses_G_list = model(image, label, "losses_G", losses_computer)
loss_G, losses_G_list = loss_G.mean(), [loss.mean() if loss is not None else None for loss in losses_G_list]
loss_G.backward()
optimizerG.step()
#tb=time.time()
#print ("generator took {}".format(tb-ta) )
#ta=time.time()
#--- discriminator update ---#
model.module.netD.zero_grad()
loss_D, losses_D_list = model(image, label, "losses_D", losses_computer)
loss_D, losses_D_list = loss_D.mean(), [loss.mean() if loss is not None else None for loss in losses_D_list]
loss_D.backward()
optimizerD.step()
#tb=time.time()
#print ("discrminator took {}".format(tb-ta) )
#--- stats update ---#
if not opt.no_EMA:
utils.update_EMA(model, cur_iter, dataloader, opt)
if cur_iter % opt.freq_print == 0:
im_saver.visualize_batch(model, image, label, cur_iter)
timer(epoch, cur_iter)
if cur_iter % opt.freq_save_ckpt == 0:
utils.save_networks(opt, cur_iter, model)
if cur_iter % opt.freq_save_latest == 0:
utils.save_networks(opt, cur_iter, model, latest=True)
if cur_iter % opt.freq_fid == 0 and cur_iter > 0:
is_best = fid_computer.update(model, cur_iter)
if is_best:
utils.save_networks(opt, cur_iter, model, best=True)
visualizer_losses(cur_iter, losses_G_list+losses_D_list)
#--- after training ---#
utils.update_EMA(model, cur_iter, dataloader, opt, force_run_stats=True)
utils.save_networks(opt, cur_iter, model)
utils.save_networks(opt, cur_iter, model, latest=True)
is_best = fid_computer.update(model, cur_iter)
if is_best:
utils.save_networks(opt, cur_iter, model, best=True)
print("The training has successfully finished")