-
Notifications
You must be signed in to change notification settings - Fork 378
/
nms_wrapper.py
78 lines (66 loc) · 2.52 KB
/
nms_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import torch
from . import nms_cuda, nms_cpu
from .soft_nms_cpu import soft_nms_cpu
def nms(dets, iou_thr, device_id=None):
    """Run NMS on ``dets`` using the CPU or the CUDA backend.

    Accepts either a torch tensor or a numpy array. A numpy array is wrapped
    as a tensor first and placed on ``cuda:<device_id>`` when ``device_id``
    is given, otherwise on the CPU. The CUDA backend is chosen whenever the
    resulting tensor lives on a GPU; the CPU backend is used in every other
    case. Results are handed back in the same container type as the input.

    Arguments:
        dets (torch.Tensor or np.ndarray): bboxes with scores.
        iou_thr (float): IoU threshold for NMS.
        device_id (int, optional): only consulted when `dets` is a numpy
            array. None selects the CPU backend, any other value selects
            the GPU with that index. Not read for tensor inputs.

    Returns:
        tuple: ``(dets[inds, :], inds)`` - the kept rows of the original
            input and their indices. For numpy input the indices are a
            numpy array, for tensor input they are a tensor.

    Raises:
        TypeError: if `dets` is neither a Tensor nor a numpy array.
    """
    as_numpy = isinstance(dets, np.ndarray)
    if as_numpy:
        # numpy input: the caller picks the backend through device_id
        if device_id is None:
            target = 'cpu'
        else:
            target = 'cuda:{}'.format(device_id)
        boxes = torch.from_numpy(dets).to(target)
    elif isinstance(dets, torch.Tensor):
        # tensor input: the backend follows the device the tensor is on
        boxes = dets
    else:
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    if boxes.shape[0] == 0:
        # no rows to suppress: skip the backends and return an empty index
        keep = boxes.new_zeros(0, dtype=torch.long)
    elif boxes.is_cuda:
        keep = nms_cuda.nms(boxes, iou_thr)
    else:
        keep = nms_cpu.nms(boxes, iou_thr)

    if as_numpy:
        # bring indices back to host memory so they can index the ndarray
        keep = keep.cpu().numpy()
    return dets[keep, :], keep
def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
    """Run Soft-NMS through the CPU implementation ``soft_nms_cpu``.

    A tensor input is detached, copied to the CPU and converted to a numpy
    array before the call; a numpy input is passed through unchanged. The
    results are converted back so that the output container matches the
    input container.

    Arguments:
        dets (torch.Tensor or np.ndarray): detections to process.
            Presumably bboxes with scores, as for `nms` - confirm against
            `soft_nms_cpu`.
        iou_thr (float): forwarded unchanged to `soft_nms_cpu`.
        method (str): 'linear' or 'gaussian'; translated to the integer
            code 1 or 2 respectively before being forwarded.
        sigma (float): forwarded unchanged to `soft_nms_cpu`.
        min_score (float): forwarded unchanged to `soft_nms_cpu`.

    Returns:
        tuple: ``(new_dets, inds)``. For tensor input both are built with
            ``dets.new_tensor`` (the indices with dtype ``torch.long``).
            For numpy input they are cast to ``np.float32`` and
            ``np.int64`` respectively.

    Raises:
        TypeError: if `dets` is neither a Tensor nor a numpy array.
        ValueError: if `method` is not 'linear' or 'gaussian'.
    """
    from_tensor = isinstance(dets, torch.Tensor)
    if from_tensor:
        boxes_np = dets.detach().cpu().numpy()
    elif isinstance(dets, np.ndarray):
        boxes_np = dets
    else:
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    # soft_nms_cpu takes the method as an integer code, not as a name
    code = {'linear': 1, 'gaussian': 2}.get(method)
    if code is None:
        raise ValueError('Invalid method for SoftNMS: {}'.format(method))

    kept, keep_idx = soft_nms_cpu(
        boxes_np, iou_thr, method=code, sigma=sigma, min_score=min_score)

    if not from_tensor:
        return kept.astype(np.float32), keep_idx.astype(np.int64)
    # new_tensor builds the outputs from the input tensor, so they follow it
    return dets.new_tensor(kept), dets.new_tensor(keep_idx, dtype=torch.long)