forked from cleardusk/3DDFA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
benchmark.py
executable file
·118 lines (87 loc) · 3.47 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python3
# coding: utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import mobilenet_v1
import time
import numpy as np
from benchmark_aflw2000 import calc_nme as calc_nme_alfw2000
from benchmark_aflw2000 import ana as ana_alfw2000
from benchmark_aflw import calc_nme as calc_nme_alfw
from benchmark_aflw import ana as ana_aflw
from utils.ddfa import ToTensorGjz, NormalizeGjz, DDFATestDataset, reconstruct_vertex
import argparse
def extract_param(checkpoint_fp, root='', filelists=None, arch='mobilenet_1', num_classes=62, device_ids=[0],
batch_size=128, num_workers=4):
map_location = {f'cuda:{i}': 'cuda:0' for i in range(8)}
checkpoint = torch.load(checkpoint_fp, map_location=map_location)['state_dict']
torch.cuda.set_device(device_ids[0])
model = getattr(mobilenet_v1, arch)(num_classes=num_classes)
model = nn.DataParallel(model, device_ids=device_ids).cuda()
model.load_state_dict(checkpoint)
dataset = DDFATestDataset(filelists=filelists, root=root,
transform=transforms.Compose([ToTensorGjz(), NormalizeGjz(mean=127.5, std=128)]))
data_loader = data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
cudnn.benchmark = True
model.eval()
end = time.time()
outputs = []
with torch.no_grad():
for _, inputs in enumerate(data_loader):
inputs = inputs.cuda()
output = model(inputs)
for i in range(output.shape[0]):
param_prediction = output[i].cpu().numpy().flatten()
outputs.append(param_prediction)
outputs = np.array(outputs, dtype=np.float32)
print(f'Extracting params take {time.time() - end: .3f}s')
return outputs
def _benchmark_aflw(outputs):
return ana_aflw(calc_nme_alfw(outputs))
def _benchmark_aflw2000(outputs):
return ana_alfw2000(calc_nme_alfw2000(outputs))
def benchmark_alfw_params(params):
outputs = []
for i in range(params.shape[0]):
lm = reconstruct_vertex(params[i])
outputs.append(lm[:2, :])
return _benchmark_aflw(outputs)
def benchmark_aflw2000_params(params):
outputs = []
for i in range(params.shape[0]):
lm = reconstruct_vertex(params[i])
outputs.append(lm[:2, :])
return _benchmark_aflw2000(outputs)
def benchmark_pipeline(arch, checkpoint_fp):
device_ids = [0]
def aflw():
params = extract_param(
checkpoint_fp=checkpoint_fp,
root='test.data/AFLW_GT_crop',
filelists='test.data/AFLW_GT_crop.list',
arch=arch,
device_ids=device_ids,
batch_size=128)
benchmark_alfw_params(params)
def aflw2000():
params = extract_param(
checkpoint_fp=checkpoint_fp,
root='test.data/AFLW2000-3D_crop',
filelists='test.data/AFLW2000-3D_crop.list',
arch=arch,
device_ids=device_ids,
batch_size=128)
benchmark_aflw2000_params(params)
aflw2000()
aflw()
def main():
parser = argparse.ArgumentParser(description='3DDFA Benchmark')
parser.add_argument('--arch', default='mobilenet_1', type=str)
parser.add_argument('-c', '--checkpoint-fp', default='models/phase1_wpdc_vdc.pth.tar', type=str)
args = parser.parse_args()
benchmark_pipeline(args.arch, args.checkpoint_fp)
if __name__ == '__main__':
main()