-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathtest.py
90 lines (77 loc) · 3.19 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
import mmcv
import torch
from mmcv.image import tensor2imgs
from mmdet3d.models import (Base3DDetector, Base3DSegmentor,
SingleStageMono3DDetector)
def _show_mmdet_results(model, data, result, out_dir, show, show_score_thr):
    """Visualize one batch through the MMDetection ``show_result`` API.

    Args:
        model (nn.Module): Wrapped model; ``model.module.show_result`` is
            called once per image.
        data (dict): One batch from the data loader. Only the ``'img'`` and
            ``'img_metas'`` entries are read.
        result (list): Predictions of this batch, indexed per image.
        out_dir (str | None): Directory for the rendered images. When
            falsy, ``out_file=None`` is passed to ``show_result``.
        show (bool): Forwarded to ``show_result`` as ``show``.
        show_score_thr (float): Forwarded to ``show_result`` as
            ``score_thr``.
    """
    # With a single result the image may arrive as a plain tensor; in every
    # other case it is unwrapped through ``.data[0]``.
    if len(result) == 1 and isinstance(data['img'][0], torch.Tensor):
        img_tensor = data['img'][0]
    else:
        img_tensor = data['img'][0].data[0]
    img_metas = data['img_metas'][0].data[0]
    # The normalization config of the first sample is applied to the batch.
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    for idx, (img, img_meta) in enumerate(zip(imgs, img_metas)):
        # Crop to the recorded ``img_shape`` (presumably dropping batch
        # padding - confirm against the pipeline), then resize back to the
        # original resolution so the drawing matches the source image.
        h, w, _ = img_meta['img_shape']
        img_show = img[:h, :w, :]
        ori_h, ori_w = img_meta['ori_shape'][:-1]
        img_show = mmcv.imresize(img_show, (ori_w, ori_h))

        if out_dir:
            out_file = osp.join(out_dir, img_meta['ori_filename'])
        else:
            out_file = None

        model.module.show_result(
            img_show,
            result[idx],
            show=show,
            out_file=out_file,
            score_thr=show_score_thr)


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Test model with single gpu.

    This method tests model with single gpu and gives the 'show' option.
    By setting ``show=True``, the visualization API of the wrapped model is
    invoked for every batch, with ``out_dir`` as the output location.

    Args:
        model (nn.Module): Model to be tested. It is accessed through
            ``model.module`` for visualization, so it is expected to be
            wrapped (e.g. by a data-parallel wrapper).
        data_loader (nn.Dataloader): Pytorch data loader.
        show (bool, optional): Whether to run the visualization step.
            Default: False.
        out_dir (str, optional): The path to save visualization results.
            Default: None.
        show_score_thr (float, optional): Score threshold handed to the
            visualization API as ``score_thr``. Default: 0.3.

    Returns:
        list[dict]: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # Model families that expose the MMDetection3D 'show_results' API.
    models_3d = (Base3DDetector, Base3DSegmentor, SingleStageMono3DDetector)

    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        if show:
            if isinstance(model.module, models_3d):
                # Visualize the results of MMDetection3D model
                # 'show_results' is MMdetection3D visualization API
                model.module.show_results(
                    data,
                    result,
                    out_dir=out_dir,
                    show=show,
                    score_thr=show_score_thr)
            else:
                # Visualize the results of MMDetection model
                # 'show_result' is MMdetection visualization API
                _show_mmdet_results(model, data, result, out_dir, show,
                                    show_score_thr)

        results.extend(result)

        # Advance the progress bar once per sample in this batch.
        for _ in range(len(result)):
            prog_bar.update()
    return results