diff --git a/README.md b/README.md index c48f8616532..6bb4929df6a 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ Supported methods: - [x] [DetectoRS](configs/detectors/README.md) - [x] [Generalized Focal Loss](configs/gfl/README.md) - [x] [CornerNet](configs/cornernet/README.md) +- [x] [YOLOv3](configs/yolo/README.md) Some other methods are also supported in [projects using MMDetection](./docs/projects.md). diff --git a/configs/yolo/README.md b/configs/yolo/README.md new file mode 100644 index 00000000000..36a78856d3f --- /dev/null +++ b/configs/yolo/README.md @@ -0,0 +1,25 @@ +# YOLOv3 + +## Introduction +``` +@misc{redmon2018yolov3, + title={YOLOv3: An Incremental Improvement}, + author={Joseph Redmon and Ali Farhadi}, + year={2018}, + eprint={1804.02767}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +## Results and Models + +| Backbone | Scale | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download | +| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :-------: | +| DarkNet-53 | 320 | 273e | 2.7 | 63.9 | 27.9 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection/v2.0/yolo/yolov3_d53_320_273e_coco/yolov3_d53_320_273e_coco-421362b6.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection/v2.0/yolo/yolov3_d53_320_273e_coco/yolov3_d53_320_273e_coco-20200819_172101.log.json) | +| DarkNet-53 | 416 | 273e | 3.8 | 61.2 | 30.9 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-416_273e_coco/yolov3_d53_mstrain-416_273e_coco-2b60fcd9.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-416_273e_coco/yolov3_d53_mstrain-416_273e_coco-20200819_173424.log.json) | +| DarkNet-53 | 608 | 273e | 7.1 | 48.1 | 33.4 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-608_273e_coco/yolov3_d53_mstrain-608_273e_coco-139f5633.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-608_273e_coco/yolov3_d53_mstrain-608_273e_coco-20200819_170820.log.json) | + + +## Credit +This implementation originates from the project of Haoyu Wu (@wuhy08) at Western Digital.
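For a quick check of a downloaded checkpoint, here is a minimal inference sketch using MMDetection's high-level Python API (`init_detector` / `inference_detector` from `mmdet.apis`); the local checkpoint and image file names are assumptions:

```python
from mmdet.apis import inference_detector, init_detector

# Config added in this PR, plus the released 608-scale checkpoint from the table above.
config_file = 'configs/yolo/yolov3_d53_mstrain-608_273e_coco.py'
checkpoint_file = 'yolov3_d53_mstrain-608_273e_coco-139f5633.pth'  # assumed local download

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo.jpg')  # hypothetical test image
# `result` is a list of 80 per-class arrays, each of shape (n, 5):
# (tl_x, tl_y, br_x, br_y, score).
model.show_result('demo.jpg', result, out_file='demo_out.jpg')
```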
diff --git a/configs/yolo/yolov3_d53_320_273e_coco.py b/configs/yolo/yolov3_d53_320_273e_coco.py new file mode 100644 index 00000000000..87359f6fb66 --- /dev/null +++ b/configs/yolo/yolov3_d53_320_273e_coco.py @@ -0,0 +1,42 @@ +_base_ = './yolov3_d53_mstrain-608_273e_coco.py' +# dataset settings +img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile', to_float32=True), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='PhotoMetricDistortion'), + dict( + type='Expand', + mean=img_norm_cfg['mean'], + to_rgb=img_norm_cfg['to_rgb'], + ratio_range=(1, 2)), + dict( + type='MinIoURandomCrop', + min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), + min_crop_size=0.3), + dict(type='Resize', img_scale=(320, 320), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(320, 320), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/yolo/yolov3_d53_mstrain-416_273e_coco.py b/configs/yolo/yolov3_d53_mstrain-416_273e_coco.py new file mode 100644 index 00000000000..d029b5cdd6b --- /dev/null +++ b/configs/yolo/yolov3_d53_mstrain-416_273e_coco.py @@ -0,0 +1,42 @@ +_base_ = './yolov3_d53_mstrain-608_273e_coco.py' +# dataset settings +img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile', to_float32=True), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='PhotoMetricDistortion'), + dict( + type='Expand', + mean=img_norm_cfg['mean'], + to_rgb=img_norm_cfg['to_rgb'], + ratio_range=(1, 2)), + dict( + type='MinIoURandomCrop', + min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), + min_crop_size=0.3), + dict(type='Resize', img_scale=[(320, 320), (416, 416)], keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(416, 416), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py b/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py new file mode 100644 index 00000000000..4e18fbed943 --- /dev/null +++ b/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py @@ -0,0 +1,121 @@ +_base_ = '../_base_/default_runtime.py' +# model settings +model = dict( + type='YOLOV3', + pretrained='open-mmlab://darknet53', + backbone=dict(type='Darknet', depth=53, out_indices=(3, 4, 5)), + neck=dict( + 
type='YOLOV3Neck', + num_scales=3, + in_channels=[1024, 512, 256], + out_channels=[512, 256, 128]), + bbox_head=dict( + type='YOLOV3Head', + num_classes=80, + in_channels=[512, 256, 128], + out_channels=[1024, 512, 256], + anchor_generator=dict( + type='YOLOAnchorGenerator', + base_sizes=[[(116, 90), (156, 198), (373, 326)], + [(30, 61), (62, 45), (59, 119)], + [(10, 13), (16, 30), (33, 23)]], + strides=[32, 16, 8]), + bbox_coder=dict(type='YOLOBBoxCoder'), + featmap_strides=[32, 16, 8], + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + reduction='sum'), + loss_conf=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + reduction='sum'), + loss_xy=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=2.0, + reduction='sum'), + loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum'))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='GridAssigner', pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0)) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + conf_thr=0.005, + nms=dict(type='nms', iou_thr=0.45), + max_per_img=100) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile', to_float32=True), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='PhotoMetricDistortion'), + dict( + type='Expand', + mean=img_norm_cfg['mean'], + to_rgb=img_norm_cfg['to_rgb'], + ratio_range=(1, 2)), + dict( + type='MinIoURandomCrop', + min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), + min_crop_size=0.3), + dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(608, 608), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=2000, # same as burn-in in darknet + warmup_ratio=0.1, + step=[218, 246]) +# runtime settings +total_epochs = 273 +evaluation = dict(interval=1, metric=['bbox']) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index e2c5c809852..6730c7b07d3 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -147,6 +147,9 @@ Please refer to [Generalized Focal Loss](https://github.com/open-mmlab/mmdetecti ### CornerNet Please 
refer to [CornerNet](https://github.com/open-mmlab/mmdetection/blob/master/configs/cornernet) for details. +### YOLOv3 +Please refer to [YOLOv3](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolo) for details. + ### Other datasets We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face). diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py index 71c012ec765..5838ff3eefb 100644 --- a/mmdet/core/anchor/__init__.py +++ b/mmdet/core/anchor/__init__.py @@ -1,4 +1,5 @@ -from .anchor_generator import AnchorGenerator, LegacyAnchorGenerator +from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator, + YOLOAnchorGenerator) from .builder import ANCHOR_GENERATORS, build_anchor_generator from .point_generator import PointGenerator from .utils import anchor_inside_flags, calc_region, images_to_levels @@ -6,5 +7,5 @@ __all__ = [ 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags', 'PointGenerator', 'images_to_levels', 'calc_region', - 'build_anchor_generator', 'ANCHOR_GENERATORS' + 'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator' ] diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py index f42a48cd912..aa2c18d8b8a 100644 --- a/mmdet/core/anchor/anchor_generator.py +++ b/mmdet/core/anchor/anchor_generator.py @@ -33,7 +33,7 @@ class AnchorGenerator(object): relative to the feature grid center in multiple feature levels. By default it is set to be None and not used. If a list of tuple of float is given, they will be used to shift the centers of anchors. - center_offset (float): The offset of center in propotion to anchors' + center_offset (float): The offset of center in proportion to anchors' width and height. By default it is 0 in V2.0. Examples: @@ -215,7 +215,7 @@ def grid_anchors(self, featmap_sizes, device='cuda'): list[torch.Tensor]: Anchors in multiple feature levels. \ The sizes of each tensor should be [N, 4], where \ N = width * height * num_base_anchors, width and height \ - are the sizes of the corresponding feature lavel, \ + are the sizes of the corresponding feature level, \ num_base_anchors is the number of anchors for that level. """ assert self.num_levels == len(featmap_sizes) @@ -590,3 +590,139 @@ def __init__(self, self.centers = [((stride - 1) / 2., (stride - 1) / 2.) for stride in strides] self.base_anchors = self.gen_base_anchors() + + +@ANCHOR_GENERATORS.register_module() +class YOLOAnchorGenerator(AnchorGenerator): + """Anchor generator for YOLO. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels. + base_sizes (list[list[tuple[int, int]]]): The basic sizes + of anchors in multiple levels. + """ + + def __init__(self, strides, base_sizes): + self.strides = [_pair(stride) for stride in strides] + self.centers = [(stride[0] / 2., stride[1] / 2.) 
+ for stride in self.strides] + self.base_sizes = [] + num_anchor_per_level = len(base_sizes[0]) + for base_sizes_per_level in base_sizes: + assert num_anchor_per_level == len(base_sizes_per_level) + self.base_sizes.append( + [_pair(base_size) for base_size in base_sizes_per_level]) + self.base_anchors = self.gen_base_anchors() + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied to""" + return len(self.base_sizes) + + def gen_base_anchors(self): + """Generate base anchors. + + Returns: + list(torch.Tensor): Base anchors of a feature grid in multiple \ + feature levels. + """ + multi_level_base_anchors = [] + for i, base_sizes_per_level in enumerate(self.base_sizes): + center = None + if self.centers is not None: + center = self.centers[i] + multi_level_base_anchors.append( + self.gen_single_level_base_anchors(base_sizes_per_level, + center)) + return multi_level_base_anchors + + def gen_single_level_base_anchors(self, base_sizes_per_level, center=None): + """Generate base anchors of a single level. + + Args: + base_sizes_per_level (list[tuple[int, int]]): Basic sizes of + anchors. + center (tuple[float], optional): The center of the base anchor + related to a single feature grid. Defaults to None. + + Returns: + torch.Tensor: Anchors in a single-level feature map. + """ + x_center, y_center = center + base_anchors = [] + for base_size in base_sizes_per_level: + w, h = base_size + + # use float anchor and the anchor's center is aligned with the + # pixel center + base_anchor = torch.Tensor([ + x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w, + y_center + 0.5 * h + ]) + base_anchors.append(base_anchor) + base_anchors = torch.stack(base_anchors, dim=0) + + return base_anchors + + def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'): + """Generate responsible anchor flags of grid cells in multiple scales. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in multiple + feature levels. + gt_bboxes (Tensor): Ground truth boxes, shape (n, 4). + device (str): Device where the anchors will be put on. + + Returns: + list(torch.Tensor): responsible flags of anchors in multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_responsible_flags = [] + for i in range(self.num_levels): + anchor_stride = self.strides[i] + flags = self.single_level_responsible_flags( + featmap_sizes[i], + gt_bboxes, + anchor_stride, + self.num_base_anchors[i], + device=device) + multi_level_responsible_flags.append(flags) + return multi_level_responsible_flags + + def single_level_responsible_flags(self, + featmap_size, + gt_bboxes, + stride, + num_base_anchors, + device='cuda'): + """Generate the responsible flags of anchors in a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps. + gt_bboxes (Tensor): Ground truth boxes, shape (n, 4). + stride (tuple(int)): Stride of the current level. + num_base_anchors (int): The number of base anchors. + device (str, optional): Device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The responsible flags of each anchor in a single level \ + feature map.
+ """ + feat_h, feat_w = featmap_size + gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device) + gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device) + gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long() + gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long() + + # row major indexing + gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x + + responsible_grid = torch.zeros( + feat_h * feat_w, dtype=torch.uint8, device=device) + responsible_grid[gt_bboxes_grid_idx] = 1 + + responsible_grid = responsible_grid[:, None].expand( + responsible_grid.size(0), num_base_anchors).contiguous().view(-1) + return responsible_grid diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index f72306b88c5..9ba7e1adcf6 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -3,10 +3,11 @@ from .atss_assigner import ATSSAssigner from .base_assigner import BaseAssigner from .center_region_assigner import CenterRegionAssigner +from .grid_assigner import GridAssigner from .max_iou_assigner import MaxIoUAssigner from .point_assigner import PointAssigner __all__ = [ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', - 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner' + 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner' ] diff --git a/mmdet/core/bbox/assigners/grid_assigner.py b/mmdet/core/bbox/assigners/grid_assigner.py new file mode 100644 index 00000000000..7390ea63706 --- /dev/null +++ b/mmdet/core/bbox/assigners/grid_assigner.py @@ -0,0 +1,155 @@ +import torch + +from ..builder import BBOX_ASSIGNERS +from ..iou_calculators import build_iou_calculator +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +@BBOX_ASSIGNERS.register_module() +class GridAssigner(BaseAssigner): + """Assign a corresponding gt bbox or background to each bbox. + + Each proposals will be assigned with `-1`, `0`, or a positive integer + indicating the ground truth index. + + - -1: don't care + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + pos_iou_thr (float): IoU threshold for positive bboxes. + neg_iou_thr (float or tuple): IoU threshold for negative bboxes. + min_pos_iou (float): Minimum iou for a bbox to be considered as a + positive bbox. Positive samples can have smaller IoU than + pos_iou_thr due to the 4th step (assign max IoU sample to each gt). + gt_max_assign_all (bool): Whether to assign all bboxes with the same + highest overlap with some gt to that gt. + """ + + def __init__(self, + pos_iou_thr, + neg_iou_thr, + min_pos_iou=.0, + gt_max_assign_all=True, + iou_calculator=dict(type='BboxOverlaps2D')): + self.pos_iou_thr = pos_iou_thr + self.neg_iou_thr = neg_iou_thr + self.min_pos_iou = min_pos_iou + self.gt_max_assign_all = gt_max_assign_all + self.iou_calculator = build_iou_calculator(iou_calculator) + + def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None): + """Assign gt to bboxes. The process is very much like the max iou + assigner, except that positive samples are constrained within the cell + that the gt boxes fell in. + + This method assign a gt bbox to every bbox (proposal/anchor), each bbox + will be assigned with -1, 0, or a positive number. -1 means don't care, + 0 means negative sample, positive number is the index (1-based) of + assigned gt. 
+ The assignment is done in the following steps; the order matters. + + 1. assign every bbox to -1 + 2. assign proposals whose iou with all gts <= neg_iou_thr to 0 + 3. for each bbox within a cell, if the iou with its nearest gt > + pos_iou_thr and the center of that gt falls inside the cell, + assign it to that bbox + 4. for each gt bbox, assign its nearest proposals within the cell the + gt bbox falls in to itself. + + Args: + bboxes (Tensor): Bounding boxes to be assigned, shape (n, 4). + box_responsible_flags (Tensor): flag to indicate whether the box is + responsible for prediction, shape (n, ) + gt_bboxes (Tensor): Ground truth boxes, shape (k, 4). + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. + """ + num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0) + + # compute iou between all gt and bboxes + overlaps = self.iou_calculator(gt_bboxes, bboxes) + + # 1. assign -1 by default + assigned_gt_inds = overlaps.new_full((num_bboxes, ), + -1, + dtype=torch.long) + + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = overlaps.new_zeros((num_bboxes, )) + if num_gts == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = overlaps.new_full((num_bboxes, ), + -1, + dtype=torch.long) + return AssignResult( + num_gts, + assigned_gt_inds, + max_overlaps, + labels=assigned_labels) + + # 2. assign negative: below + # for each anchor, which gt best overlaps with it + # for each anchor, the max iou of all gts + # shape of max_overlaps == argmax_overlaps == num_bboxes + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + if isinstance(self.neg_iou_thr, float): + assigned_gt_inds[(max_overlaps >= 0) + & (max_overlaps <= self.neg_iou_thr)] = 0 + elif isinstance(self.neg_iou_thr, (tuple, list)): + assert len(self.neg_iou_thr) == 2 + assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0]) + & (max_overlaps <= self.neg_iou_thr[1])] = 0 + + # 3. assign positive: falls into responsible cell and above + # positive IOU threshold, the order matters. + # the prior condition of comparison is to filter out all + # unrelated anchors, i.e. not box_responsible_flags + overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1. + + # calculate max_overlaps again, but this time we only consider IOUs + # for anchors responsible for prediction + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # for each gt, which anchor best overlaps with it + # for each gt, the max iou of all proposals + # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts + gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1) + + pos_inds = (max_overlaps > + self.pos_iou_thr) & box_responsible_flags.type(torch.bool) + assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 + + # 4.
assign positive to max overlapped anchors within responsible cell + for i in range(num_gts): + if gt_max_overlaps[i] > self.min_pos_iou: + if self.gt_max_assign_all: + max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \ + box_responsible_flags.type(torch.bool) + assigned_gt_inds[max_iou_inds] = i + 1 + elif box_responsible_flags[gt_argmax_overlaps[i]]: + assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1 + + # assign labels of positive anchors + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + + else: + assigned_labels = None + + return AssignResult( + num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels) diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py index 6e3e54af14c..a99f77e104b 100644 --- a/mmdet/core/bbox/assigners/max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/max_iou_assigner.py @@ -31,7 +31,7 @@ class MaxIoUAssigner(BaseAssigner): `bboxes` and `gt_bboxes_ignore`, or the contrary. match_low_quality (bool): Whether to allow low quality matches. This is usually allowed for RPN and single stage detectors, but not allowed - in the second stage. Details are demonetrated in Step 4. + in the second stage. Details are demonstrated in Step 4. gpu_assign_thr (int): The upper bound of the number of GT for GPU assign. When the number of gt is above this threshold, will assign on CPU device. Negative values mean not assign on CPU. diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py index b8ebc369091..953ac5a849e 100644 --- a/mmdet/core/bbox/coder/__init__.py +++ b/mmdet/core/bbox/coder/__init__.py @@ -3,8 +3,9 @@ from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder from .pseudo_bbox_coder import PseudoBBoxCoder from .tblr_bbox_coder import TBLRBBoxCoder +from .yolo_bbox_coder import YOLOBBoxCoder __all__ = [ 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', - 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder' + 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder' ] diff --git a/mmdet/core/bbox/coder/yolo_bbox_coder.py b/mmdet/core/bbox/coder/yolo_bbox_coder.py new file mode 100644 index 00000000000..2a1dc34fcd3 --- /dev/null +++ b/mmdet/core/bbox/coder/yolo_bbox_coder.py @@ -0,0 +1,86 @@ +import torch + +from ..builder import BBOX_CODERS +from .base_bbox_coder import BaseBBoxCoder + + +@BBOX_CODERS.register_module() +class YOLOBBoxCoder(BaseBBoxCoder): + """YOLO BBox coder. + + Following `YOLO `_, this coder divides the + image into grids, and encodes bboxes (x1, y1, x2, y2) into (cx, cy, dw, dh). + cx, cy in [0., 1.] denote the relative center position w.r.t. the center of + bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`. + + Args: + eps (float): Min value of cx, cy when encoding. + """ + + def __init__(self, eps=1e-6): + super(BaseBBoxCoder, self).__init__() + self.eps = eps + + def encode(self, bboxes, gt_bboxes, stride): + """Get box regression transformation deltas that can be used to + transform the ``bboxes`` into the ``gt_bboxes``. + + Args: + bboxes (torch.Tensor): Source boxes, e.g., anchors. + gt_bboxes (torch.Tensor): Target of the transformation, e.g., + ground-truth boxes. + stride (torch.Tensor | int): Stride of bboxes.
+ + Returns: + torch.Tensor: Box transformation deltas. + """ + + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5 + y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5 + w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0] + h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1] + x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5 + y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5 + w = bboxes[..., 2] - bboxes[..., 0] + h = bboxes[..., 3] - bboxes[..., 1] + w_target = torch.log((w_gt / w).clamp(min=self.eps)) + h_target = torch.log((h_gt / h).clamp(min=self.eps)) + x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp( + self.eps, 1 - self.eps) + y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp( + self.eps, 1 - self.eps) + encoded_bboxes = torch.stack( + [x_center_target, y_center_target, w_target, h_target], dim=-1) + return encoded_bboxes + + def decode(self, bboxes, pred_bboxes, stride): + """Apply transformation `pred_bboxes` to `bboxes`. + + Args: + bboxes (torch.Tensor): Basic boxes, e.g. anchors. + pred_bboxes (torch.Tensor): Encoded boxes with shape (..., 4). + stride (torch.Tensor | int): Strides of bboxes. + + Returns: + torch.Tensor: Decoded boxes. + """ + assert pred_bboxes.size(0) == bboxes.size(0) + assert pred_bboxes.size(-1) == bboxes.size(-1) == 4 + x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5 + y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5 + w = bboxes[..., 2] - bboxes[..., 0] + h = bboxes[..., 3] - bboxes[..., 1] + # Get outputs x, y + x_center_pred = (pred_bboxes[..., 0] - 0.5) * stride + x_center + y_center_pred = (pred_bboxes[..., 1] - 0.5) * stride + y_center + w_pred = torch.exp(pred_bboxes[..., 2]) * w + h_pred = torch.exp(pred_bboxes[..., 3]) * h + + decoded_bboxes = torch.stack( + (x_center_pred - w_pred / 2, y_center_pred - h_pred / 2, + x_center_pred + w_pred / 2, y_center_pred + h_pred / 2), + dim=-1) + + return decoded_bboxes diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py index 6d4cc1db79e..8a306c2ddb1 100644 --- a/mmdet/models/backbones/__init__.py +++ b/mmdet/models/backbones/__init__.py @@ -1,3 +1,4 @@ +from .darknet import Darknet from .detectors_resnet import DetectoRS_ResNet from .detectors_resnext import DetectoRS_ResNeXt from .hourglass import HourglassNet @@ -10,5 +11,5 @@ __all__ = [ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', - 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt' + 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet' ] diff --git a/mmdet/models/backbones/darknet.py b/mmdet/models/backbones/darknet.py new file mode 100644 index 00000000000..517fe262592 --- /dev/null +++ b/mmdet/models/backbones/darknet.py @@ -0,0 +1,199 @@ +# Copyright (c) 2019 Western Digital Corporation or its affiliates. + +import logging + +import torch.nn as nn +from mmcv.cnn import ConvModule, constant_init, kaiming_init +from mmcv.runner import load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES + + +class ResBlock(nn.Module): + """The basic residual block used in Darknet. Each ResBlock consists of two + ConvModules and the input is added to the final output. Each ConvModule is + composed of Conv, BN, and LeakyReLU. In the YoloV3 paper, the first convLayer + has half the number of filters of the second convLayer.
The + first convLayer has a filter size of 1x1 and the second one a filter + size of 3x3. + + Args: + in_channels (int): The input channels. Must be even. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + """ + + def __init__(self, + in_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1)): + super(ResBlock, self).__init__() + assert in_channels % 2 == 0 # ensure the in_channels is even + half_in_channels = in_channels // 2 + + # shortcut + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + self.conv1 = ConvModule(in_channels, half_in_channels, 1, **cfg) + self.conv2 = ConvModule( + half_in_channels, in_channels, 3, padding=1, **cfg) + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.conv2(out) + out = out + residual + + return out + + +@BACKBONES.register_module() +class Darknet(nn.Module): + """Darknet backbone. + + Args: + depth (int): Depth of Darknet. Currently, only depth 53 is supported. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + + Example: + >>> from mmdet.models import Darknet + >>> import torch + >>> self = Darknet(depth=53) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ...
+ (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + + # Dict(depth: (layers, channels)) + arch_settings = { + 53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512), + (512, 1024))) + } + + def __init__(self, + depth=53, + out_indices=(3, 4, 5), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1), + norm_eval=True): + super(Darknet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for darknet') + self.depth = depth + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.layers, self.channels = self.arch_settings[depth] + + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg) + + self.cr_blocks = ['conv1'] + for i, n_layers in enumerate(self.layers): + layer_name = f'conv_res_block{i + 1}' + in_c, out_c = self.channels[i] + self.add_module( + layer_name, + self.make_conv_res_block(in_c, out_c, n_layers, **cfg)) + self.cr_blocks.append(layer_name) + + self.norm_eval = norm_eval + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.cr_blocks): + cr_block = getattr(self, layer_name) + x = cr_block(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + else: + raise TypeError('pretrained must be a str or None') + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for i in range(self.frozen_stages): + m = getattr(self, self.cr_blocks[i]) + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(Darknet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + @staticmethod + def make_conv_res_block(in_channels, + out_channels, + res_repeat, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', + negative_slope=0.1)): + """In Darknet backbone, ConvLayer is usually followed by ResBlock. This + function will make that. The Conv layers always have 3x3 filters with + stride=2. The number of the filters in Conv layer is the same as the + out channels of the ResBlock. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + res_repeat (int): The number of ResBlocks. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). 
+ """ + + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + model = nn.Sequential() + model.add_module( + 'conv', + ConvModule( + in_channels, out_channels, 3, stride=2, padding=1, **cfg)) + for idx in range(res_repeat): + model.add_module('res{}'.format(idx), + ResBlock(out_channels, **cfg)) + return model diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 9f5ec181ac6..3c6405b803f 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -19,11 +19,13 @@ from .retina_sepbn_head import RetinaSepBNHead from .rpn_head import RPNHead from .ssd_head import SSDHead +from .yolo_head import YOLOV3Head __all__ = [ 'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', - 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'PAAHead' + 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'PAAHead', + 'YOLOV3Head' ] diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py new file mode 100644 index 00000000000..f93d620e39c --- /dev/null +++ b/mmdet/models/dense_heads/yolo_head.py @@ -0,0 +1,489 @@ +# Copyright (c) 2019 Western Digital Corporation or its affiliates. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, normal_init + +from mmdet.core import (build_anchor_generator, build_assigner, + build_bbox_coder, build_sampler, force_fp32, + images_to_levels, multi_apply, multiclass_nms) +from ..builder import HEADS, build_loss +from .base_dense_head import BaseDenseHead + + +@HEADS.register_module() +class YOLOV3Head(BaseDenseHead): + """YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767. + + Args: + num_classes (int): The number of object classes (w/o background) + in_channels (List[int]): Number of input channels per scale. + out_channels (List[int]): The number of output channels per scale + before the final 1x1 layer. Default: (1024, 512, 256). + anchor_generator (dict): Config dict for anchor generator + bbox_coder (dict): Config of bounding box coder. + featmap_strides (List[int]): The stride of each scale. + Should be in descending order. Default: (32, 16, 8). + one_hot_smoother (float): Set a non-zero value to enable label-smooth + Default: 0. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + loss_cls (dict): Config of classification loss. + loss_conf (dict): Config of confidence loss. + loss_xy (dict): Config of xy coordinate loss. + loss_wh (dict): Config of wh coordinate loss. + train_cfg (dict): Training config of YOLOV3 head. Default: None. + test_cfg (dict): Testing config of YOLOV3 head. Default: None. 
+ """ + + def __init__(self, + num_classes, + in_channels, + out_channels=(1024, 512, 256), + anchor_generator=dict( + type='YOLOAnchorGenerator', + base_sizes=[[(116, 90), (156, 198), (373, 326)], + [(30, 61), (62, 45), (59, 119)], + [(10, 13), (16, 30), (33, 23)]], + strides=[32, 16, 8]), + bbox_coder=dict(type='YOLOBBoxCoder'), + featmap_strides=[32, 16, 8], + one_hot_smoother=0., + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_conf=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_xy=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_wh=dict(type='MSELoss', loss_weight=1.0), + train_cfg=None, + test_cfg=None): + super(YOLOV3Head, self).__init__() + # Check params + assert (len(in_channels) == len(out_channels) == len(featmap_strides)) + + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.featmap_strides = featmap_strides + self.train_cfg = train_cfg + self.test_cfg = test_cfg + if self.train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + if hasattr(self.train_cfg, 'sampler'): + sampler_cfg = self.train_cfg.sampler + else: + sampler_cfg = dict(type='PseudoSampler') + self.sampler = build_sampler(sampler_cfg, context=self) + + self.one_hot_smoother = one_hot_smoother + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.bbox_coder = build_bbox_coder(bbox_coder) + self.anchor_generator = build_anchor_generator(anchor_generator) + + self.loss_cls = build_loss(loss_cls) + self.loss_conf = build_loss(loss_conf) + self.loss_xy = build_loss(loss_xy) + self.loss_wh = build_loss(loss_wh) + # usually the numbers of anchors for each level are the same + # except SSD detectors + self.num_anchors = self.anchor_generator.num_base_anchors[0] + assert len( + self.anchor_generator.num_base_anchors) == len(featmap_strides) + self._init_layers() + + @property + def num_levels(self): + return len(self.featmap_strides) + + @property + def num_attrib(self): + """int: number of attributes in pred_map, bboxes (4) + + objectness (1) + num_classes""" + + return 5 + self.num_classes + + def _init_layers(self): + self.convs_bridge = nn.ModuleList() + self.convs_pred = nn.ModuleList() + for i in range(self.num_levels): + conv_bridge = ConvModule( + self.in_channels[i], + self.out_channels[i], + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + conv_pred = nn.Conv2d(self.out_channels[i], + self.num_anchors * self.num_attrib, 1) + + self.convs_bridge.append(conv_bridge) + self.convs_pred.append(conv_pred) + + def init_weights(self): + """Initialize weights of the head.""" + for m in self.convs_pred: + normal_init(m, std=0.01) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple[Tensor]: A tuple of multi-level predication map, each is a + 4D-tensor of shape (batch_size, 5+num_classes, height, width). 
+ """ + + assert len(feats) == self.num_levels + pred_maps = [] + for i in range(self.num_levels): + x = feats[i] + x = self.convs_bridge[i](x) + pred_map = self.convs_pred[i](x) + pred_maps.append(pred_map) + + return tuple(pred_maps), + + @force_fp32(apply_to=('pred_maps', )) + def get_bboxes(self, pred_maps, img_metas, cfg=None, rescale=False): + """Transform network output for a batch into bbox predictions. + + Args: + pred_maps (list[Tensor]): Raw predictions for a batch of images. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used + rescale (bool): If True, return boxes in original image space + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class label of the + corresponding box. + """ + result_list = [] + num_levels = len(pred_maps) + for img_id in range(len(img_metas)): + pred_maps_list = [ + pred_maps[i][img_id].detach() for i in range(num_levels) + ] + scale_factor = img_metas[img_id]['scale_factor'] + proposals = self._get_bboxes_single(pred_maps_list, scale_factor, + cfg, rescale) + result_list.append(proposals) + return result_list + + def _get_bboxes_single(self, + pred_maps_list, + scale_factor, + cfg, + rescale=False): + """Transform outputs for a single batch item into bbox predictions. + + Args: + pred_maps_list (list[Tensor]): Prediction maps for different scales + of each single image in the batch. + scale_factor (ndarray): Scale factor of the image arrange as + (w_scale, h_scale, w_scale, h_scale). + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(pred_maps_list) == self.num_levels + multi_lvl_bboxes = [] + multi_lvl_cls_scores = [] + multi_lvl_conf_scores = [] + num_levels = len(pred_maps_list) + featmap_sizes = [ + pred_maps_list[i].shape[-2:] for i in range(num_levels) + ] + multi_lvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, pred_maps_list[0][0].device) + for i in range(self.num_levels): + # get some key info for current scale + pred_map = pred_maps_list[i] + stride = self.featmap_strides[i] + + # (h, w, num_anchors*num_attrib) -> (h*w*num_anchors, num_attrib) + pred_map = pred_map.permute(1, 2, 0).reshape(-1, self.num_attrib) + + pred_map[..., :2] = torch.sigmoid(pred_map[..., :2]) + bbox_pred = self.bbox_coder.decode(multi_lvl_anchors[i], + pred_map[..., :4], stride) + # conf and cls + conf_pred = torch.sigmoid(pred_map[..., 4]).view(-1) + cls_pred = torch.sigmoid(pred_map[..., 5:]).view( + -1, self.num_classes) # Cls pred one-hot. 
+ + # Filtering out all predictions with conf < conf_thr + conf_thr = cfg.get('conf_thr', -1) + conf_inds = conf_pred.ge(conf_thr).nonzero().flatten() + bbox_pred = bbox_pred[conf_inds, :] + cls_pred = cls_pred[conf_inds, :] + conf_pred = conf_pred[conf_inds] + + # Get top-k prediction + nms_pre = cfg.get('nms_pre', -1) + if 0 < nms_pre < conf_pred.size(0): + _, topk_inds = conf_pred.topk(nms_pre) + bbox_pred = bbox_pred[topk_inds, :] + cls_pred = cls_pred[topk_inds, :] + conf_pred = conf_pred[topk_inds] + + # Save the result of current scale + multi_lvl_bboxes.append(bbox_pred) + multi_lvl_cls_scores.append(cls_pred) + multi_lvl_conf_scores.append(conf_pred) + + # Merge the results of different scales together + multi_lvl_bboxes = torch.cat(multi_lvl_bboxes) + multi_lvl_cls_scores = torch.cat(multi_lvl_cls_scores) + multi_lvl_conf_scores = torch.cat(multi_lvl_conf_scores) + + if multi_lvl_conf_scores.size(0) == 0: + return torch.zeros((0, 5)), torch.zeros((0, )) + + if rescale: + multi_lvl_bboxes /= multi_lvl_bboxes.new_tensor(scale_factor) + + # In mmdet 2.x, the class_id for background is num_classes. + # i.e., the last column. + padding = multi_lvl_cls_scores.new_zeros(multi_lvl_cls_scores.shape[0], + 1) + multi_lvl_cls_scores = torch.cat([multi_lvl_cls_scores, padding], + dim=1) + + det_bboxes, det_labels = multiclass_nms( + multi_lvl_bboxes, + multi_lvl_cls_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=multi_lvl_conf_scores) + + return det_bboxes, det_labels + + @force_fp32(apply_to=('pred_maps', )) + def loss(self, + pred_maps, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute loss of the head. + + Args: + pred_maps (list[Tensor]): Prediction map for each scale level, + shape (N, num_anchors * num_attrib, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_imgs = len(img_metas) + device = pred_maps[0][0].device + + featmap_sizes = [ + pred_maps[i].shape[-2:] for i in range(self.num_levels) + ] + multi_level_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + responsible_flag_list = [] + for img_id in range(len(img_metas)): + responsible_flag_list.append( + self.anchor_generator.responsible_flags( + featmap_sizes, gt_bboxes[img_id], device)) + + target_maps_list, neg_maps_list = self.get_targets( + anchor_list, responsible_flag_list, gt_bboxes, gt_labels) + + losses_cls, losses_conf, losses_xy, losses_wh = multi_apply( + self.loss_single, pred_maps, target_maps_list, neg_maps_list) + + return dict( + loss_cls=losses_cls, + loss_conf=losses_conf, + loss_xy=losses_xy, + loss_wh=losses_wh) + + def loss_single(self, pred_map, target_map, neg_map): + """Compute the loss of a single scale level for the whole batch. + + Args: + pred_map (Tensor): Raw predictions for a single level. + target_map (Tensor): The Ground-Truth target for a single level. + neg_map (Tensor): The negative masks for a single level. + + Returns: + tuple: + loss_cls (Tensor): Classification loss. + loss_conf (Tensor): Confidence loss.
+ loss_xy (Tensor): Regression loss of x, y coordinate. + loss_wh (Tensor): Regression loss of w, h coordinate. + """ + + num_imgs = len(pred_map) + pred_map = pred_map.permute(0, 2, 3, + 1).reshape(num_imgs, -1, self.num_attrib) + neg_mask = neg_map.float() + pos_mask = target_map[..., 4] + pos_and_neg_mask = neg_mask + pos_mask + pos_mask = pos_mask.unsqueeze(dim=-1) + if torch.max(pos_and_neg_mask) > 1.: + warnings.warn('There is overlap between pos and neg sample.') + pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.) + + pred_xy = pred_map[..., :2] + pred_wh = pred_map[..., 2:4] + pred_conf = pred_map[..., 4] + pred_label = pred_map[..., 5:] + + target_xy = target_map[..., :2] + target_wh = target_map[..., 2:4] + target_conf = target_map[..., 4] + target_label = target_map[..., 5:] + + loss_cls = self.loss_cls(pred_label, target_label, weight=pos_mask) + loss_conf = self.loss_conf( + pred_conf, target_conf, weight=pos_and_neg_mask) + loss_xy = self.loss_xy(pred_xy, target_xy, weight=pos_mask) + loss_wh = self.loss_wh(pred_wh, target_wh, weight=pos_mask) + + return loss_cls, loss_conf, loss_xy, loss_wh + + def get_targets(self, anchor_list, responsible_flag_list, gt_bboxes_list, + gt_labels_list): + """Compute target maps for anchors in multiple images. + + Args: + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_total_anchors, 4). + responsible_flag_list (list[list[Tensor]]): Multi level responsible + flags of each image. Each element is a tensor of shape + (num_total_anchors, ) + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + gt_labels_list (list[Tensor]): Ground truth labels of each box. + + Returns: + tuple: Usually returns a tuple containing learning targets. + - target_map_list (list[Tensor]): Target map of each level. + - neg_map_list (list[Tensor]): Negative map of each level. + """ + num_imgs = len(anchor_list) + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + + results = multi_apply(self._get_targets_single, anchor_list, + responsible_flag_list, gt_bboxes_list, + gt_labels_list) + + all_target_maps, all_neg_maps = results + assert num_imgs == len(all_target_maps) == len(all_neg_maps) + target_maps_list = images_to_levels(all_target_maps, num_level_anchors) + neg_maps_list = images_to_levels(all_neg_maps, num_level_anchors) + + return target_maps_list, neg_maps_list + + def _get_targets_single(self, anchors, responsible_flags, gt_bboxes, + gt_labels): + """Generate matching bounding box prior and converted GT. + + Args: + anchors (list[Tensor]): Multi-level anchors of the image. + responsible_flags (list[Tensor]): Multi-level responsible flags of + anchors + gt_bboxes (Tensor): Ground truth bboxes of single image. + gt_labels (Tensor): Ground truth labels of single image. 
+ + Returns: + tuple: + target_map (Tensor): Prediction target map of each + scale level, shape (num_total_anchors, + 5+num_classes) + neg_map (Tensor): Negative map of each scale level, + shape (num_total_anchors,) + """ + + anchor_strides = [] + for i in range(len(anchors)): + anchor_strides.append( + torch.tensor(self.featmap_strides[i], + device=gt_bboxes.device).repeat(len(anchors[i]))) + concat_anchors = torch.cat(anchors) + concat_responsible_flags = torch.cat(responsible_flags) + + anchor_strides = torch.cat(anchor_strides) + assert len(anchor_strides) == len(concat_anchors) == \ + len(concat_responsible_flags) + assign_result = self.assigner.assign(concat_anchors, + concat_responsible_flags, + gt_bboxes) + sampling_result = self.sampler.sample(assign_result, concat_anchors, + gt_bboxes) + + target_map = concat_anchors.new_zeros( + concat_anchors.size(0), self.num_attrib) + + target_map[sampling_result.pos_inds, :4] = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes, + anchor_strides[sampling_result.pos_inds]) + + target_map[sampling_result.pos_inds, 4] = 1 + + gt_labels_one_hot = F.one_hot( + gt_labels, num_classes=self.num_classes).float() + if self.one_hot_smoother != 0: # label smooth + gt_labels_one_hot = gt_labels_one_hot * ( + 1 - self.one_hot_smoother + ) + self.one_hot_smoother / self.num_classes + target_map[sampling_result.pos_inds, 5:] = gt_labels_one_hot[ + sampling_result.pos_assigned_gt_inds] + + neg_map = concat_anchors.new_zeros( + concat_anchors.size(0), dtype=torch.uint8) + neg_map[sampling_result.neg_inds] = 1 + + return target_map, neg_map diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 1cd8e621fd1..f7d3332bc0d 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -20,10 +20,12 @@ from .rpn import RPN from .single_stage import SingleStageDetector from .two_stage import TwoStageDetector +from .yolo import YOLOV3 __all__ = [ 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', - 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA' + 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', + 'YOLOV3' ] diff --git a/mmdet/models/detectors/yolo.py b/mmdet/models/detectors/yolo.py new file mode 100644 index 00000000000..240aab20f85 --- /dev/null +++ b/mmdet/models/detectors/yolo.py @@ -0,0 +1,18 @@ +# Copyright (c) 2019 Western Digital Corporation or its affiliates.
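+# YOLOv3 needs no detector-specific logic: the Darknet backbone, YOLOV3Neck +# and YOLOV3Head are simply composed by the generic SingleStageDetector +# train/test pipeline; this subclass only registers the name.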
+ +from ..builder import DETECTORS +from .single_stage import SingleStageDetector + + +@DETECTORS.register_module() +class YOLOV3(SingleStageDetector): + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py index e3d10afbe23..14cf519f2e7 100644 --- a/mmdet/models/losses/cross_entropy_loss.py +++ b/mmdet/models/losses/cross_entropy_loss.py @@ -83,7 +83,7 @@ def binary_cross_entropy(pred, if weight is not None: weight = weight.float() loss = F.binary_cross_entropy_with_logits( - pred, label.float(), weight=class_weight, reduction='none') + pred, label.float(), pos_weight=class_weight, reduction='none') # do the reduction for the weighted loss loss = weight_reduce_loss( loss, weight, reduction=reduction, avg_factor=avg_factor) diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py index 2fdea02f4fa..5027bf30c01 100644 --- a/mmdet/models/necks/__init__.py +++ b/mmdet/models/necks/__init__.py @@ -6,8 +6,9 @@ from .nasfcos_fpn import NASFCOS_FPN from .pafpn import PAFPN from .rfp import RFP +from .yolo_neck import YOLOV3Neck __all__ = [ 'FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', 'NASFCOS_FPN', - 'RFP' + 'RFP', 'YOLOV3Neck' ] diff --git a/mmdet/models/necks/yolo_neck.py b/mmdet/models/necks/yolo_neck.py new file mode 100644 index 00000000000..c2f9b9ef385 --- /dev/null +++ b/mmdet/models/necks/yolo_neck.py @@ -0,0 +1,136 @@ +# Copyright (c) 2019 Western Digital Corporation or its affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from ..builder import NECKS + + +class DetectionBlock(nn.Module): + """Detection block in YOLO neck. + + Let out_channels = n; the DetectionBlock contains: + Six ConvLayers, 1 Conv2D Layer and 1 YoloLayer. + The first 6 ConvLayers are arranged as follows: + 1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn, 3x3x2n. + The Conv2D layer is 1x1x255. + Some blocks will have a branch after the fifth ConvLayer. + The number of input channels is arbitrary (in_channels). + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1).
+ """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1)): + super(DetectionBlock, self).__init__() + double_out_channels = out_channels * 2 + + # shortcut + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.conv1 = ConvModule(in_channels, out_channels, 1, **cfg) + self.conv2 = ConvModule( + out_channels, double_out_channels, 3, padding=1, **cfg) + self.conv3 = ConvModule(double_out_channels, out_channels, 1, **cfg) + self.conv4 = ConvModule( + out_channels, double_out_channels, 3, padding=1, **cfg) + self.conv5 = ConvModule(double_out_channels, out_channels, 1, **cfg) + + def forward(self, x): + tmp = self.conv1(x) + tmp = self.conv2(tmp) + tmp = self.conv3(tmp) + tmp = self.conv4(tmp) + out = self.conv5(tmp) + return out + + +@NECKS.register_module() +class YOLOV3Neck(nn.Module): + """The neck of YOLOV3. + + It can be treated as a simplified version of FPN. It + will take the result from Darknet backbone and do some upsampling and + concatenation. It will finally output the detection result. + + Note: + The input feats should be from top to bottom. + i.e., from high-lvl to low-lvl + But YOLOV3Neck will process them in reversed order. + i.e., from bottom (high-lvl) to top (low-lvl) + + Args: + num_scales (int): The number of scales / stages. + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + """ + + def __init__(self, + num_scales, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1)): + super(YOLOV3Neck, self).__init__() + assert (num_scales == len(in_channels) == len(out_channels)) + self.num_scales = num_scales + self.in_channels = in_channels + self.out_channels = out_channels + + # shortcut + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + # To support arbitrary scales, the code looks awful, but it works. + # Better solution is welcomed. 
+ self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg) + for i in range(1, self.num_scales): + in_c, out_c = self.in_channels[i], self.out_channels[i] + self.add_module(f'conv{i}', ConvModule(in_c, out_c, 1, **cfg)) + # in_c + out_c : High-lvl feats will be cat with low-lvl feats + self.add_module(f'detect{i+1}', + DetectionBlock(in_c + out_c, out_c, **cfg)) + + def forward(self, feats): + assert len(feats) == self.num_scales + + # processed from bottom (high-lvl) to top (low-lvl) + outs = [] + out = self.detect1(feats[-1]) + outs.append(out) + + for i, x in enumerate(reversed(feats[:-1])): + conv = getattr(self, f'conv{i+1}') + tmp = conv(out) + + # Cat with low-lvl feats + tmp = F.interpolate(tmp, scale_factor=2) + tmp = torch.cat((tmp, x), 1) + + detect = getattr(self, f'detect{i+2}') + out = detect(tmp) + outs.append(out) + + return tuple(outs) + + def init_weights(self): + """Initialize the weights of module.""" + # init is done in ConvModule + pass diff --git a/tests/test_anchor.py b/tests/test_anchor.py index 3830f435c50..813852ea3e7 100644 --- a/tests/test_anchor.py +++ b/tests/test_anchor.py @@ -145,6 +145,49 @@ def test_anchor_generator_with_tuples(): assert torch.equal(anchor, anchor_tuples) +def test_yolo_anchor_generator(): + from mmdet.core.anchor import build_anchor_generator + if torch.cuda.is_available(): + device = 'cuda' + else: + device = 'cpu' + + anchor_generator_cfg = dict( + type='YOLOAnchorGenerator', + strides=[32, 16, 8], + base_sizes=[ + [(116, 90), (156, 198), (373, 326)], + [(30, 61), (62, 45), (59, 119)], + [(10, 13), (16, 30), (33, 23)], + ]) + + featmap_sizes = [(14, 18), (28, 36), (56, 72)] + anchor_generator = build_anchor_generator(anchor_generator_cfg) + + # check base anchors + expected_base_anchors = [ + torch.Tensor([[-42.0000, -29.0000, 74.0000, 61.0000], + [-62.0000, -83.0000, 94.0000, 115.0000], + [-170.5000, -147.0000, 202.5000, 179.0000]]), + torch.Tensor([[-7.0000, -22.5000, 23.0000, 38.5000], + [-23.0000, -14.5000, 39.0000, 30.5000], + [-21.5000, -51.5000, 37.5000, 67.5000]]), + torch.Tensor([[-1.0000, -2.5000, 9.0000, 10.5000], + [-4.0000, -11.0000, 12.0000, 19.0000], + [-12.5000, -7.5000, 20.5000, 15.5000]]) + ] + base_anchors = anchor_generator.base_anchors + for i, base_anchor in enumerate(base_anchors): + assert base_anchor.allclose(expected_base_anchors[i]) + + # check number of base anchors for each level + assert anchor_generator.num_base_anchors == [3, 3, 3] + + # check anchor generation + anchors = anchor_generator.grid_anchors(featmap_sizes, device) + assert len(anchors) == 3 + + def test_retina_anchor(): from mmdet.models import build_head if torch.cuda.is_available(): diff --git a/tests/test_coder.py b/tests/test_coder.py new file mode 100644 index 00000000000..b45c16e97ea --- /dev/null +++ b/tests/test_coder.py @@ -0,0 +1,21 @@ +import torch + +from mmdet.core.bbox.coder import YOLOBBoxCoder + + +def test_yolo_bbox_coder(): + coder = YOLOBBoxCoder() + bboxes = torch.Tensor([[-42., -29., 74., 61.], [-10., -29., 106., 61.], + [22., -29., 138., 61.], [54., -29., 170., 61.]]) + pred_bboxes = torch.Tensor([[0.4709, 0.6152, 0.1690, -0.4056], + [0.5399, 0.6653, 0.1162, -0.4162], + [0.4654, 0.6618, 0.1548, -0.4301], + [0.4786, 0.6197, 0.1896, -0.4479]]) + grid_size = 32 + expected_decode_bboxes = torch.Tensor( + [[-53.6102, -10.3096, 83.7478, 49.6824], + [-15.8700, -8.3901, 114.4236, 50.9693], + [11.1822, -8.0924, 146.6034, 50.4476], + [41.2068, -8.9232, 181.4236, 48.5840]]) + assert expected_decode_bboxes.allclose( + 
coder.decode(bboxes, pred_bboxes, grid_size)) diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index 9d12e2c5fea..0bf90a7db76 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -85,7 +85,8 @@ def test_rpn_forward(): 'foveabox/fovea_align_r50_fpn_gn-head_4x4_2x_coco.py', # 'free_anchor/retinanet_free_anchor_r50_fpn_1x_coco.py', # 'atss/atss_r50_fpn_1x_coco.py', # not ready for topk - 'reppoints/reppoints_moment_r50_fpn_1x_coco.py' + 'reppoints/reppoints_moment_r50_fpn_1x_coco.py', + 'yolo/yolov3_d53_mstrain-608_273e_coco.py' ]) def test_single_stage_forward_gpu(cfg_file): if not torch.cuda.is_available():