# train.py -- training script, forked from ronghuaiyang/arcface-pytorch.
from __future__ import print_function
import os
from data import Dataset
import torch
from torch.utils import data
import torch.nn.functional as F
from models import *
import torchvision
from utils import Visualizer, view_model
import torch
import numpy as np
import random
import time
from config import Config
from torch.nn import DataParallel
from torch.optim.lr_scheduler import StepLR
from test import *
def save_model(model, save_path, name, iter_cnt):
    """Save the model's ``state_dict`` to ``<save_path>/<name>_<iter_cnt>.pth``.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        save_path: directory that receives the checkpoint; it is created
            if it does not exist yet.
        name: file-name prefix of the checkpoint.
        iter_cnt: counter appended to the prefix (converted with ``str``).

    Returns:
        The full path of the checkpoint file that was written.
    """
    # torch.save does not create missing directories, so make sure the
    # target exists first. An empty save_path means "current directory".
    if save_path and not os.path.isdir(save_path):
        os.makedirs(save_path)
    save_name = os.path.join(save_path, name + '_' + str(iter_cnt) + '.pth')
    torch.save(model.state_dict(), save_name)
    return save_name
if __name__ == '__main__':
    # Training entry point: build the data loader, backbone, classification
    # head, optimizer and LR schedule from Config, then run the epoch loop
    # with a checkpoint and an LFW evaluation after every epoch.
    opt = Config()
    if opt.display:
        visualizer = Visualizer()
    device = torch.device("cuda")

    train_dataset = Dataset(opt.train_root, opt.train_list, phase='train', input_shape=opt.input_shape)
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=opt.train_batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)

    identity_list = get_lfw_list(opt.lfw_test_list)
    img_paths = [os.path.join(opt.lfw_root, each) for each in identity_list]

    print('{} train iters per epoch:'.format(len(trainloader)))

    if opt.loss == 'focal_loss':
        criterion = FocalLoss(gamma=2)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if opt.backbone == 'resnet18':
        model = resnet_face18(use_se=opt.use_se)
    elif opt.backbone == 'resnet34':
        model = resnet34()
    elif opt.backbone == 'resnet50':
        model = resnet50()
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('unsupported backbone: {}'.format(opt.backbone))

    # The margin heads are called with (feature, label); the plain linear
    # fallback only accepts the feature tensor.
    metric_uses_label = True
    if opt.metric == 'add_margin':
        metric_fc = AddMarginProduct(512, opt.num_classes, s=30, m=0.35)
    elif opt.metric == 'arc_margin':
        metric_fc = ArcMarginProduct(512, opt.num_classes, s=30, m=0.5, easy_margin=opt.easy_margin)
    elif opt.metric == 'sphere':
        metric_fc = SphereProduct(512, opt.num_classes, m=4)
    else:
        # Use the explicitly imported torch namespace rather than relying on
        # a star import to provide `nn`.
        metric_fc = torch.nn.Linear(512, opt.num_classes)
        metric_uses_label = False

    print(model)
    model.to(device)
    model = DataParallel(model)
    metric_fc.to(device)
    metric_fc = DataParallel(metric_fc)

    param_groups = [{'params': model.parameters()}, {'params': metric_fc.parameters()}]
    if opt.optimizer == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    lr=opt.lr, weight_decay=opt.weight_decay)
    else:
        optimizer = torch.optim.Adam(param_groups,
                                     lr=opt.lr, weight_decay=opt.weight_decay)
    scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1)

    # Defined up front so the test-accuracy plot still has an x value when
    # the loader yields no batches.
    iters = 0
    start = time.time()
    for i in range(opt.max_epoch):
        model.train()
        # `batch` (not `data`) so the imported torch.utils.data module is
        # not shadowed.
        for ii, batch in enumerate(trainloader):
            data_input, label = batch
            data_input = data_input.to(device)
            label = label.to(device).long()
            feature = model(data_input)
            if metric_uses_label:
                output = metric_fc(feature, label)
            else:
                output = metric_fc(feature)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Global iteration index across epochs.
            iters = i * len(trainloader) + ii

            if iters % opt.print_freq == 0:
                predictions = np.argmax(output.detach().cpu().numpy(), axis=1)
                targets = label.cpu().numpy()
                acc = np.mean((predictions == targets).astype(int))
                speed = opt.print_freq / (time.time() - start)
                time_str = time.asctime(time.localtime(time.time()))
                loss_value = loss.item()
                print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.format(time_str, i, ii, speed, loss_value, acc))
                if opt.display:
                    visualizer.display_current_results(iters, loss_value, name='train_loss')
                    visualizer.display_current_results(iters, acc, name='train_acc')

                # Restart the timer for the next reporting window.
                start = time.time()

        # range() stops at max_epoch - 1, so that is the final epoch; always
        # checkpoint it in addition to the periodic saves.
        if i % opt.save_interval == 0 or i == opt.max_epoch - 1:
            save_model(model, opt.checkpoints_path, opt.backbone, i)

        model.eval()
        acc = lfw_test(model, img_paths, identity_list, opt.lfw_test_list, opt.test_batch_size)
        if opt.display:
            visualizer.display_current_results(iters, acc, name='test_acc')

        # Advance the LR schedule once per epoch, after the optimizer steps;
        # stepping before the first optimizer.step() skips the first value
        # of the schedule on PyTorch >= 1.1.
        scheduler.step()