# File: video_gpuaccel_demo.py
# Origin: fork of open-mmlab/mmdetection.
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from typing import Tuple
import cv2
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress
from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample
try:
import ffmpegcv
except ImportError:
raise ImportError(
'Please install ffmpegcv with:\n\n pip install ffmpegcv')
def parse_args(argv=None):
    """Parse the command-line options of the GPU-accelerated video demo.

    Args:
        argv (list[str], optional): Argument strings to parse, without the
            program name. Defaults to None, in which case argparse falls
            back to ``sys.argv[1:]`` (the behaviour of the original
            zero-argument call).

    Returns:
        argparse.Namespace: Parsed options with the attributes ``video``,
        ``config``, ``checkpoint``, ``device``, ``score_thr``, ``out``,
        ``show``, ``nvdecode`` and ``wait_time``.
    """
    parser = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')

    # Required positional inputs.
    parser.add_argument('video', help='Video file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')

    # Inference options.
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')

    # Output / display options.
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--show', action='store_true', help='Show video')
    parser.add_argument(
        '--nvdecode', action='store_true', help='Use NVIDIA decoder')
    # NOTE(review): the help text states seconds, and main() forwards this
    # value unchanged to mmcv.imshow -- confirm which unit (and numeric
    # type) that function actually expects.
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (s), 0 is block')

    return parser.parse_args(argv)
def prefetch_batch_input_shape(model: nn.Module,
                               ori_wh: Tuple[int, int]) -> Tuple[int, int]:
    """Probe the input shape the model will see for frames of a given size.

    A blank frame of the requested size is pushed through the model's test
    pipeline and data preprocessor, and the ``batch_input_shape`` recorded
    on the resulting data sample is returned.

    Args:
        model (nn.Module): Detector exposing ``cfg`` (with
            ``test_dataloader.dataset.pipeline``) and ``data_preprocessor``.
        ori_wh (Tuple[int, int]): Width and height of the source frames.

    Returns:
        Tuple[int, int]: The ``batch_input_shape`` attribute of the
        preprocessed sample (presumably ``(height, width)``, since the
        caller reverses it to obtain a ``(width, height)`` resize target).
    """
    import copy  # local import: only this helper needs it

    w, h = ori_wh

    # Work on a private copy of the pipeline config so that swapping the
    # loading transform does not permanently modify ``model.cfg``.
    pipeline_cfg = copy.deepcopy(model.cfg.test_dataloader.dataset.pipeline)
    # Frames arrive as in-memory arrays rather than file paths.
    pipeline_cfg[0].type = 'LoadImageFromNDArray'
    test_pipeline = Compose(pipeline_cfg)

    # An all-zero frame is enough: only the resulting shape matters.
    data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
    data = test_pipeline(data)
    _, data_sample = model.data_preprocessor([data], False)
    return data_sample[0].batch_input_shape
def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
              ori_shape: Tuple[int, int]) -> dict:
    """Wrap an already-resized frame into the dict passed to the model.

    Args:
        frame_resize (np.ndarray): Resized frame; its first two dimensions
            must equal ``batch_input_shape``.
        batch_input_shape (Tuple[int, int]): Shape the model expects.
        ori_shape (Tuple[int, int]): Shape of the original video frame.

    Returns:
        dict: ``'inputs'`` holds the frame as a tensor with its last axis
        moved to the front, ``'data_sample'`` holds a ``DetDataSample``
        carrying ``img_shape``, ``ori_shape`` and ``scale_factor``.

    Raises:
        AssertionError: If the frame size differs from
            ``batch_input_shape``.
    """
    assert frame_resize.shape[:2] == batch_input_shape

    sample = DetDataSample()

    # Per-dimension ratio between the model input and the original frame.
    # NOTE(review): the tuple follows the index order of the two shape
    # arguments -- confirm this matches the order consumers of
    # ``scale_factor`` expect.
    ratio_dim0 = batch_input_shape[0] / ori_shape[0]
    ratio_dim1 = batch_input_shape[1] / ori_shape[1]
    meta = {
        'img_shape': batch_input_shape,
        'ori_shape': ori_shape,
        'scale_factor': (ratio_dim0, ratio_dim1),
    }
    sample.set_metainfo(meta)

    # Move the last axis first (presumably HWC -> CHW).
    tensor = torch.from_numpy(frame_resize).permute((2, 0, 1))
    return {'inputs': tensor, 'data_sample': sample}
def main():
    """Run the detector over a video and show and/or save the result.

    The video is opened twice: once at its original resolution (frames used
    for drawing) and once resized to the model's input shape (frames fed to
    the model). Each pair of frames is run through ``model.test_step`` and
    the predictions are drawn on the original frame.

    Raises:
        ValueError: If neither ``--out`` nor ``--show`` was given.
    """
    args = parse_args()
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if not (args.out or args.show):
        raise ValueError(
            'Please specify at least one operation (save/show the '
            'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    if args.nvdecode:
        VideoCapture = ffmpegcv.VideoCaptureNV
    else:
        VideoCapture = ffmpegcv.VideoCapture

    video_origin = VideoCapture(args.video)
    video_resize = None
    video_writer = None
    try:
        batch_input_shape = prefetch_batch_input_shape(
            model, (video_origin.width, video_origin.height))
        ori_shape = (video_origin.height, video_origin.width)
        # The reader takes its resize target in the reverse order of
        # ``batch_input_shape``.
        resize_wh = batch_input_shape[::-1]
        video_resize = VideoCapture(
            args.video,
            resize=resize_wh,
            resize_keepratio=True,
            resize_keepratioalign='topleft')

        if args.out:
            video_writer = ffmpegcv.VideoWriter(
                args.out, fps=video_origin.fps)

        with torch.no_grad():
            for frame_resize, frame_origin in zip(
                    track_iter_progress(video_resize), video_origin):
                data = pack_data(frame_resize, batch_input_shape, ori_shape)
                result = model.test_step([data])[0]

                visualizer.add_datasample(
                    name='video',
                    image=frame_origin,
                    data_sample=result,
                    draw_gt=False,
                    show=False,
                    pred_score_thr=args.score_thr)
                frame_mask = visualizer.get_image()

                if args.show:
                    cv2.namedWindow('video', 0)
                    mmcv.imshow(frame_mask, 'video', args.wait_time)
                if video_writer is not None:
                    video_writer.write(frame_mask)
    finally:
        # Release everything even when decoding or inference fails, so
        # readers/writers opened above are not left dangling.
        if video_writer is not None:
            video_writer.release()
        video_origin.release()
        if video_resize is not None:
            video_resize.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()