-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
executable file
·202 lines (176 loc) · 7.26 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import copy
import json
import os
import warnings
from absl import app
from datetime import datetime
import torch
from tensorboardX import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from tqdm import trange
from diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
from model import UNet
from helpers import *
from config import FLAGS
# Target device for the models and tensors this script moves with `.to(device)`:
# CUDA device index 0.
device = torch.device('cuda', 0)
def ema(source, target, decay):
    """Update ``target`` in place as an exponential moving average of ``source``.

    For every entry of ``source.state_dict()`` (parameters and buffers):

    * floating-point / complex tensors become
      ``decay * target + (1 - decay) * source``;
    * any other tensor (e.g. integer counters) is copied from ``source``
      verbatim, since blending it as a float and truncating back to an
      integer dtype would corrupt it.

    Args:
        source: module whose current weights are blended in.
        target: module updated in place; must have the same state-dict keys.
        decay: weight kept from ``target`` (typically close to 1).

    Raises:
        KeyError: if ``target`` lacks a key present in ``source``.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    # state_dict() tensors share storage with the modules, so in-place ops
    # below update ``target`` itself; no_grad keeps autograd out of it.
    with torch.no_grad():
        for key, src in source_dict.items():
            tgt = target_dict[key]
            if tgt.is_floating_point() or tgt.is_complex():
                tgt.mul_(decay).add_(src, alpha=1 - decay)
            else:
                tgt.copy_(src)
def warmup_lr(step, warmup=None):
    """Return the LR multiplier for linear warmup at scheduler step ``step``.

    The factor rises linearly from 0 at step 0 to 1 at ``warmup`` steps and
    stays at 1 afterwards. A non-positive warmup means "no warmup" and yields
    1.0 (instead of dividing by zero).

    Args:
        step: current scheduler step (as passed by ``LambdaLR``).
        warmup: number of warmup steps; defaults to ``FLAGS.warmup``.

    Returns:
        A float in ``[0, 1]`` for non-negative ``step``.
    """
    if warmup is None:
        warmup = FLAGS.warmup
    if warmup <= 0:
        return 1.0
    return min(step, warmup) / warmup
def train():
    """Train the diffusion UNet on CIFAR-10, starting from ``<logdir>/ckpt.pt``.

    All hyperparameters come from ``FLAGS``. The ``net_model`` and
    ``ema_model`` weights are loaded from ``<logdir>/ckpt.pt`` (``torch.load``
    raises if that file is missing). NOTE: the optimizer, LR scheduler and the
    step counter are NOT restored from the checkpoint; they start fresh.

    While training, this periodically:
      * writes EMA sample grids to ``<logdir>/sample/<step>.png`` and TensorBoard;
      * saves checkpoints to ``<logdir>/ckpt_<YYYYmmdd_HHMMSS>.pt`` (not to
        ``ckpt.pt``, which is what this function and ``eval()`` read);
      * appends IS/FID metrics as one JSON object per line to ``<logdir>/eval.txt``.
    """
    # dataset
    dataset = CIFAR10(
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Map pixel values from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=FLAGS.batch_size, shuffle=True,
        num_workers=FLAGS.num_workers, drop_last=True)
    datalooper = infiniteloop(dataloader)

    # model setup: resume both weight sets from the existing checkpoint.
    net_model = UNet(
        T=FLAGS.T, ch=FLAGS.ch, ch_mult=FLAGS.ch_mult, attn=FLAGS.attn,
        num_res_blocks=FLAGS.num_res_blocks, dropout=FLAGS.dropout)
    # Load to CPU: load_state_dict copies into the (still CPU-resident) modules
    # anyway, and this avoids depending on the GPU the checkpoint came from.
    ckpt = torch.load(
        os.path.join(FLAGS.logdir, 'ckpt.pt'), map_location='cpu')
    net_model.load_state_dict(ckpt['net_model'])
    ema_model = copy.deepcopy(net_model)
    ema_model.load_state_dict(ckpt['ema_model'])
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    trainer = GaussianDiffusionTrainer(
        net_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T).to(device)
    net_sampler = GaussianDiffusionSampler(
        net_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T, FLAGS.img_size,
        FLAGS.mean_type, FLAGS.var_type).to(device)
    ema_sampler = GaussianDiffusionSampler(
        ema_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T, FLAGS.img_size,
        FLAGS.mean_type, FLAGS.var_type).to(device)
    if FLAGS.parallel:
        trainer = torch.nn.DataParallel(trainer)
        net_sampler = torch.nn.DataParallel(net_sampler)
        ema_sampler = torch.nn.DataParallel(ema_sampler)

    # log setup
    sample_dir = os.path.join(FLAGS.logdir, 'sample')
    os.makedirs(sample_dir, exist_ok=True)
    # Fixed starting noise, so sample grids are comparable across steps.
    x_T = torch.randn(FLAGS.sample_size, 3, FLAGS.img_size, FLAGS.img_size)
    x_T = x_T.to(device)
    # Element 0 of a CIFAR10 batch is the image tensor; (x + 1) / 2 undoes
    # the [-1, 1] normalization for display.
    grid = (make_grid(next(iter(dataloader))[0][:FLAGS.sample_size]) + 1) / 2
    writer = SummaryWriter(FLAGS.logdir)
    try:
        writer.add_image('real_sample', grid)
        writer.flush()
        # backup all arguments
        with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f:
            f.write(FLAGS.flags_into_string())
        # show model size (parameter count, printed in units of 2**20)
        model_size = sum(
            param.data.nelement() for param in net_model.parameters())
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        # start training
        with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
            for step in pbar:
                # train
                optim.zero_grad()
                x_0 = next(datalooper).to(device)
                loss = trainer(x_0).mean()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), FLAGS.grad_clip)
                optim.step()
                sched.step()
                ema(net_model, ema_model, FLAGS.ema_decay)

                # log: sync the loss to the host once for both consumers.
                loss_value = loss.item()
                writer.add_scalar('loss', loss_value, step)
                pbar.set_postfix(loss='%.3f' % loss_value)

                # sample
                if FLAGS.sample_step > 0 and step % FLAGS.sample_step == 0:
                    # Samples come from ema_sampler, so it is ema_model that
                    # must be in eval mode (e.g. to disable its dropout).
                    ema_model.eval()
                    with torch.no_grad():
                        x_0 = ema_sampler(x_T)
                        grid = (make_grid(x_0) + 1) / 2
                        path = os.path.join(sample_dir, '%d.png' % step)
                        save_image(grid, path)
                        writer.add_image('sample', grid, step)
                    ema_model.train()

                # save
                if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                    ckpt = {
                        'net_model': net_model.state_dict(),
                        'ema_model': ema_model.state_dict(),
                        'sched': sched.state_dict(),
                        'optim': optim.state_dict(),
                        'step': step,
                        'x_T': x_T,
                    }
                    # Timestamped file name keeps every saved checkpoint.
                    date_str = datetime.now().strftime('%Y%m%d_%H%M%S')
                    torch.save(
                        ckpt,
                        os.path.join(FLAGS.logdir, f'ckpt_{date_str}.pt'))

                # evaluate
                if FLAGS.eval_step > 0 and step % FLAGS.eval_step == 0:
                    net_IS, net_FID, _ = evaluate(net_sampler, net_model)
                    ema_IS, ema_FID, _ = evaluate(ema_sampler, ema_model)
                    metrics = {
                        'IS': net_IS[0],
                        'IS_std': net_IS[1],
                        'FID': net_FID,
                        'IS_EMA': ema_IS[0],
                        'IS_std_EMA': ema_IS[1],
                        'FID_EMA': ema_FID
                    }
                    pbar.write(
                        "%d/%d " % (step, FLAGS.total_steps) +
                        ", ".join(
                            '%s:%.3f' % (k, v) for k, v in metrics.items()))
                    for name, value in metrics.items():
                        writer.add_scalar(name, value, step)
                    writer.flush()
                    with open(
                            os.path.join(FLAGS.logdir, 'eval.txt'), 'a') as f:
                        metrics['step'] = step
                        f.write(json.dumps(metrics) + "\n")
    finally:
        # Flush and close TensorBoard logs even if training is interrupted.
        writer.close()
def eval():
    """Score the raw and the EMA weights stored in ``<logdir>/ckpt.pt``.

    Both weight sets are loaded in turn into one UNet, scored with
    ``evaluate`` (IS mean/std and FID are printed), and the first 256
    returned samples are saved as a 16-per-row image grid in ``<logdir>``.
    """
    # A single UNet/sampler pair is reused for both weight sets.
    unet = UNet(
        T=FLAGS.T, ch=FLAGS.ch, ch_mult=FLAGS.ch_mult, attn=FLAGS.attn,
        num_res_blocks=FLAGS.num_res_blocks, dropout=FLAGS.dropout)
    diffusion_sampler = GaussianDiffusionSampler(
        unet, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T, img_size=FLAGS.img_size,
        mean_type=FLAGS.mean_type, var_type=FLAGS.var_type).to(device)
    if FLAGS.parallel:
        diffusion_sampler = torch.nn.DataParallel(diffusion_sampler)

    checkpoint = torch.load(os.path.join(FLAGS.logdir, 'ckpt.pt'))
    # (checkpoint key, report format, output image), processed in this order.
    runs = (
        ('net_model',
         "Model : IS:%6.3f(%.3f), FID:%7.3f",
         'samples.png'),
        ('ema_model',
         "Model(EMA): IS:%6.3f(%.3f), FID:%7.3f",
         'samples_ema.png'),
    )
    for state_key, report_fmt, image_name in runs:
        unet.load_state_dict(checkpoint[state_key])
        (is_mean, is_std), fid, generated = evaluate(diffusion_sampler, unet)
        print(report_fmt % (is_mean, is_std, fid))
        save_image(
            torch.tensor(generated[:256]),
            os.path.join(FLAGS.logdir, image_name),
            nrow=16)
def main(argv):
    """absl entry point: run training and/or evaluation as selected by flags.

    ``argv`` holds absl's leftover positional arguments and is unused.
    """
    # Silence the FutureWarning emitted when inception_v3 is initialized.
    warnings.simplefilter(action='ignore', category=FutureWarning)
    do_train = FLAGS.train
    do_eval = FLAGS.eval
    if do_train:
        train()
    if do_eval:
        eval()
    if not (do_train or do_eval):
        print('Add --train and/or --eval to execute corresponding tasks')
if __name__ == '__main__':
    # Hand control to absl, which parses the command-line flags and then
    # invokes main() with the remaining argv.
    app.run(main=main)