Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameterize max_det + inference default at 1000 #3215

Merged
merged 6 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def detect(opt):
pred = model(img, augment=opt.augment)[0]

# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms, max_det=opt.max_det)
t2 = time_synchronized()

# Apply Classifier
Expand Down Expand Up @@ -153,6 +153,7 @@ def detect(opt):
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--max-det', type=int, default=300, help='maximum number of detections per image')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
Expand Down
6 changes: 4 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,19 +215,21 @@ class NMS(nn.Module):
conf = 0.25 # confidence threshold
iou = 0.45 # IoU threshold
classes = None # (optional list) filter by class
max_det = 300 # maximum number of detections per image

def __init__(self):
super(NMS, self).__init__()

def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)


class AutoShape(nn.Module):
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
classes = None # (optional list) filter by class
max_det = 300 # maximum number of detections per image

def __init__(self, model):
super(AutoShape, self).__init__()
Expand Down Expand Up @@ -285,7 +287,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
t.append(time_synchronized())

# Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])

Expand Down
3 changes: 1 addition & 2 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def wh_iou(wh1, wh2):


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
labels=(), max_det=300):
"""Runs Non-Maximum Suppression (NMS) on inference results

Returns:
Expand All @@ -498,7 +498,6 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non

# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
Expand Down