Skip to content

Commit b44845f

Browse files
authored
Add unittest for infernce (open-mmlab#18)
1 parent 6071697 commit b44845f

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

tests/test_inference.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import mmcv
2+
import pytest
3+
import torch
4+
import torch.nn as nn
5+
6+
from mmaction.apis import inference_recognizer, init_recognizer
7+
8+
config_file = 'configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py' # noqa: E501
9+
label_path = 'demo/label_map.txt'
10+
video_path = 'demo/demo.mp4'
11+
12+
13+
def test_init_recognizer():
14+
with pytest.raises(TypeError):
15+
init_recognizer(dict(config_file=None))
16+
17+
if torch.cuda.is_available():
18+
device = 'cuda:0'
19+
else:
20+
device = 'cpu'
21+
22+
model = init_recognizer(config_file, None, device)
23+
24+
config = mmcv.Config.fromfile(config_file)
25+
config.model.backbone.pretrained = None
26+
27+
isinstance(model, nn.Module)
28+
if torch.cuda.is_available():
29+
assert next(model.parameters()).is_cuda is True
30+
else:
31+
assert next(model.parameters()).is_cuda is False
32+
assert model.cfg.model.backbone.pretrained is None
33+
34+
35+
def test_inference_recognizer():
36+
if torch.cuda.is_available():
37+
device = 'cuda:0'
38+
else:
39+
device = 'cpu'
40+
model = init_recognizer(config_file, None, device)
41+
42+
for ops in model.cfg.data.test.pipeline:
43+
if ops['type'] == 'TenCrop':
44+
# Use CenterCrop to reduce memory in order to pass CI
45+
ops['type'] = 'CenterCrop'
46+
47+
top5_label = inference_recognizer(model, video_path, label_path)
48+
scores = [item[1] for item in top5_label]
49+
assert len(top5_label) == 5
50+
assert scores == sorted(scores, reverse=True)

0 commit comments

Comments
 (0)