From 6f174a66d0885a37fbaea5485678aa1c700585d7 Mon Sep 17 00:00:00 2001
From: wangxinjiang
Date: Fri, 27 Dec 2019 16:03:42 +0800
Subject: [PATCH 01/11] starting to merge fsaf

---
 configs/fsaf/README.md                          |  39 ++
 configs/fsaf/fsaf_r101_fpn_1x.py                | 131 +++++++
 configs/fsaf/fsaf_r50_fpn_1x.py                 | 131 +++++++
 configs/fsaf/fsaf_x101_64x4d_fpn_1x.py          | 133 +++++++
 mmdet/core/anchor/__init__.py                   |   1 +
 mmdet/core/anchor/anchor_tblr_target.py         | 192 ++++++++++
 mmdet/core/bbox/__init__.py                     |   5 +-
 mmdet/core/bbox/assigners/__init__.py           |   3 +-
 .../bbox/assigners/effective_area_assigner.py   | 167 ++++++++
 .../bbox/iou_calculators/iou2d_calculator.py    |  57 +++
 mmdet/core/bbox/transforms.py                   |  65 ++++
 mmdet/models/anchor_heads/__init__.py           |   3 +-
 mmdet/models/anchor_heads/fsaf_head.py          | 360 ++++++++++++++++++
 mmdet/models/detectors/__init__.py              |   3 +-
 mmdet/models/detectors/fsaf.py                  |  16 +
 15 files changed, 1301 insertions(+), 5 deletions(-)
 create mode 100644 configs/fsaf/README.md
 create mode 100644 configs/fsaf/fsaf_r101_fpn_1x.py
 create mode 100644 configs/fsaf/fsaf_r50_fpn_1x.py
 create mode 100644 configs/fsaf/fsaf_x101_64x4d_fpn_1x.py
 create mode 100644 mmdet/core/anchor/anchor_tblr_target.py
 create mode 100644 mmdet/core/bbox/assigners/effective_area_assigner.py
 create mode 100644 mmdet/models/anchor_heads/fsaf_head.py
 create mode 100644 mmdet/models/detectors/fsaf.py

diff --git a/configs/fsaf/README.md b/configs/fsaf/README.md
new file mode 100644
index 00000000000..a9162426340
--- /dev/null
+++ b/configs/fsaf/README.md
@@ -0,0 +1,39 @@
+# Feature Selective Anchor-Free Module for Single-Shot Object Detection
+
+FSAF is an anchor-free method published in CVPR 2019 ([https://arxiv.org/pdf/1903.00621.pdf](https://arxiv.org/pdf/1903.00621.pdf)).
+In practice it is equivalent to an anchor-based method with exactly one anchor at each feature-map position of each FPN level, and that is how it is implemented here.
+Only the anchor-free branch is released, since it is more compatible with the current framework and has a smaller computational cost.
+
+In the original paper, feature-map locations within the central 0.2-0.5 region of a gt box are labelled as ignored. However,
+we found empirically that a hard threshold (0.2-0.2) gives a further gain in performance.
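+
+A minimal sketch of the resulting labelling rule for a single feature-map
+location against one gt box. The function and parameter names
+(`label_location`, `pos_scale`, `neg_scale`) are illustrative only; they play
+the role of `pos_area_thr`/`neg_area_thr` in the configs, and with the default
+0.2/0.2 setting the ignore ring is empty (the paper's setting would be
+`neg_scale=0.5`):
+
+```python
+def label_location(cx, cy, gt, pos_scale=0.2, neg_scale=0.2):
+    """Label one feature-map location (cx, cy) against a gt box (x1, y1, x2, y2)."""
+    x1, y1, x2, y2 = gt
+    gx, gy = (x1 + x2) / 2, (y1 + y2) / 2  # gt center
+    w, h = x2 - x1, y2 - y1
+
+    def inside(scale):
+        # inside the central region of the gt box shrunk by `scale`
+        return abs(cx - gx) < scale * w / 2 and abs(cy - gy) < scale * h / 2
+
+    if inside(pos_scale):
+        return 'positive'  # effective (center) region
+    if inside(neg_scale):
+        return 'ignored'   # ring between the two thresholds
+    return 'negative'      # background
+```
+
+Both ignore settings, 0.2-0.5 and 0.2-0.2, are compared on COCO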
(see the table below) + +## Main Results +### Results on R50/R101/X101-FPN + +| Backbone | ignore range | ms-train| Lr schd | Train Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download | +|:----------:| :-------: |:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:| +| R-50 | 0.2-0.5 | N | 1x | 2.97 | 0.43 | 12.3 | 35.7 | fsaf-r50-fpn-1x-20191226-ee860779ad09031f7a58d193b0438a26.pth | +| R-50 | 0.2-0.2 | N | 1x | 2.97 | 0.43 | 13.0 | 37.0 | fsaf-r50-fpn-1x-20191225-d388a744213c3bb187e073f5ccdde5d6.pth | +| R-101 | 0.2-0.5 | N | 1x | 4.87 | 0.58 | 10.6 | 37.8 | fsaf-r101-fpn-1x-20191226-736730f8db59ac0a28262034484ed57d.pth | +| R-101 | 0.2-0.2 | N | 1x | 4.87 | 0.58 | 10.8 | 39.1 | fsaf-r101-fpn-1x-20191225-e1dbbcba40933cd8fc0d0174d1b13aa7.pth | +| X-101 | 0.2-0.2 | N | 1x | 9.02 | 1.23 | 5.6 | 41.8 | fsaf-x101-64x4d-fpn-1x-20191225-82d23b4bc07f2d666eed71fe48de49a9.pth | + +**Notes:** + - *1x and 2x mean the model is trained for 12 and 24 epochs, respectively.* + - *All results are obtained with a single model and single-scale test.* + - *X-101 backbone represents ResNext-101-64x4d.* + - *All pretrained backbones use pytorch style.* + - *All models are trained on 8 Titan-XP gpus and tested on a single gpu.* + +## Citations +BibTeX reference is as follows. +``` +@inproceedings{zhu2019feature, + title={Feature Selective Anchor-Free Module for Single-Shot Object Detection}, + author={Zhu, Chenchen and He, Yihui and Savvides, Marios}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={840--849}, + year={2019} +} +``` diff --git a/configs/fsaf/fsaf_r101_fpn_1x.py b/configs/fsaf/fsaf_r101_fpn_1x.py new file mode 100644 index 00000000000..720127a5e65 --- /dev/null +++ b/configs/fsaf/fsaf_r101_fpn_1x.py @@ -0,0 +1,131 @@ +# model settings +model = dict( + type='RetinaNet', + pretrained='torchvision://resnet101', + backbone=dict( + type='ResNet', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs=True, + num_outs=5), + bbox_head=dict( + type='FSAFHead', + num_classes=81, + in_channels=256, + stacked_convs=4, + feat_channels=256, + octave_base_scale=1, + scales_per_octave=1, + anchor_ratios=[1.0], + anchor_strides=[8, 16, 32, 64, 128], + target_normalizer=1.0, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + reduction='none'), + loss_bbox=dict(type='IoULossTBLR', + eps=1e-6, + loss_weight=1.0, + reduction='none'))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='EffectiveAreaAssigner', + pos_area_thr=0.2, + neg_area_thr=0.2, + min_pos_iof=0.01), + allowed_border=-1, + pos_weight=-1, + debug=False) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 
'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/fsaf_r101_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/fsaf/fsaf_r50_fpn_1x.py b/configs/fsaf/fsaf_r50_fpn_1x.py new file mode 100644 index 00000000000..24f2d217bc4 --- /dev/null +++ b/configs/fsaf/fsaf_r50_fpn_1x.py @@ -0,0 +1,131 @@ +# model settings +model = dict( + type='FSAF', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs=True, + num_outs=5), + bbox_head=dict( + type='FSAFHead', + num_classes=81, + in_channels=256, + stacked_convs=4, + feat_channels=256, + octave_base_scale=1, + scales_per_octave=1, + anchor_ratios=[1.0], + anchor_strides=[8, 16, 32, 64, 128], + target_normalizer=1.0, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + reduction='none'), + loss_bbox=dict(type='IoULossTBLR', + eps=1e-6, + loss_weight=1.0, + reduction='none'))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='EffectiveAreaAssigner', + pos_area_thr=0.2, + neg_area_thr=0.2, + min_pos_iof=0.01), + allowed_border=-1, + pos_weight=-1, + debug=False) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline 
= [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/fsaf_r50_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/fsaf/fsaf_x101_64x4d_fpn_1x.py b/configs/fsaf/fsaf_x101_64x4d_fpn_1x.py new file mode 100644 index 00000000000..66f2d29799e --- /dev/null +++ b/configs/fsaf/fsaf_x101_64x4d_fpn_1x.py @@ -0,0 +1,133 @@ +# model settings +model = dict( + type='RetinaNet', + pretrained='open-mmlab://resnext101_64x4d', + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs=True, + num_outs=5), + bbox_head=dict( + type='FSAFHead', + num_classes=81, + in_channels=256, + stacked_convs=4, + feat_channels=256, + octave_base_scale=1, + scales_per_octave=1, + anchor_ratios=[1.0], + anchor_strides=[8, 16, 32, 64, 128], + target_normalizer=1.0, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + reduction='none'), + loss_bbox=dict(type='IoULossTBLR', + eps=1e-6, + loss_weight=1.0, + reduction='none'))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='EffectiveAreaAssigner', + pos_area_thr=0.2, + neg_area_thr=0.2, + min_pos_iof=0.01), + allowed_border=-1, + pos_weight=-1, + debug=False) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 
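+    # (this 1x COCO training pipeline is identical across the three FSAF configs)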
+] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/fsaf_x101_64x4d_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py index d129974d96d..36724c4af2e 100644 --- a/mmdet/core/anchor/__init__.py +++ b/mmdet/core/anchor/__init__.py @@ -1,5 +1,6 @@ from .anchor_generator import AnchorGenerator, LegacyAnchorGenerator from .builder import build_anchor_generator + from .point_generator import PointGenerator from .registry import ANCHOR_GENERATORS from .utils import anchor_inside_flags, calc_region, images_to_levels diff --git a/mmdet/core/anchor/anchor_tblr_target.py b/mmdet/core/anchor/anchor_tblr_target.py new file mode 100644 index 00000000000..51670ac2293 --- /dev/null +++ b/mmdet/core/anchor/anchor_tblr_target.py @@ -0,0 +1,192 @@ +import torch + +from ..bbox import (PseudoSampler, assign_and_sample, bboxes2tblr, + build_assigner) +from ..utils import multi_apply + + +def anchor_tblr_target(anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + target_normalizer, + cfg, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + sampling=True, + unmap_outputs=True): + """Compute regression and classification targets for anchors. + + Args: + anchor_list (list[list]): Multi level anchors of each image. + valid_flag_list (list[list]): Multi level valid flags of each image. + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + target_means (Iterable): Mean value of regression targets. + target_normalizer (float): Std value of regression targets. + cfg (dict): RPN train configs. 
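+        gt_bboxes_ignore_list (list[Tensor] | None): Ignored gt bboxes of
+            each image, if any.
+        gt_labels_list (list[Tensor] | None): Gt labels of each image.
+        sampling (bool): Whether to sample positives/negatives; if False,
+            a PseudoSampler keeps all of them.
+        unmap_outputs (bool): Whether to map the targets back to the
+            original (full) set of anchors.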
+ + Returns: + tuple + """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, + pos_inds_list, neg_inds_list, pos_assigned_gt_inds) = multi_apply( + anchor_target_single, + anchor_list, + valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + target_normalizer=target_normalizer, + cfg=cfg, + sampling=sampling, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # split targets to a list w.r.t. multiple levels + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors) + pos_assigned_gt_inds_list = images_to_levels(pos_assigned_gt_inds, + num_level_anchors) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg, + pos_assigned_gt_inds_list) + + +def images_to_levels(target, num_level_anchors): + """Convert targets by image to targets by feature level. + + [target_img0, target_img1] -> [target_level0, target_level1, ...] 
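+
+    Each returned level tensor has shape (num_imgs, num_level_anchors_i, ...);
+    note that the trailing ``squeeze(0)`` drops the image dimension when the
+    batch holds a single image.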
+ """ + target = torch.stack(target, 0) + level_targets = [] + start = 0 + for n in num_level_anchors: + end = start + n + level_targets.append(target[:, start:end].squeeze(0)) + start = end + return level_targets + + +def anchor_target_single(flat_anchors, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + target_normalizer, + cfg, + sampling=True, + unmap_outputs=True): + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 6 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + + if sampling: + assign_result, sampling_result = assign_and_sample( + anchors, gt_bboxes, gt_bboxes_ignore, None, cfg) + else: + bbox_assigner = build_assigner(cfg.assigner) + assign_result = bbox_assigner.assign(anchors, gt_bboxes, + gt_bboxes_ignore, gt_labels) + bbox_sampler = PseudoSampler() + sampling_result = bbox_sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + pos_assigned_gt_inds = anchors.new_full((num_valid_anchors, ), + -1, + dtype=torch.long) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_bbox_targets = bboxes2tblr(sampling_result.pos_bboxes, + sampling_result.pos_gt_bboxes, + target_normalizer) + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + pos_assigned_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds + if gt_labels is None: + labels[pos_inds] = 1 + else: + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + if cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap(labels, num_total_anchors, inside_flags) + label_weights = unmap(label_weights, num_total_anchors, inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + pos_assigned_gt_inds = unmap( + pos_assigned_gt_inds, num_total_anchors, inside_flags, fill=-1) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds, pos_assigned_gt_inds) + + +def anchor_inside_flags(flat_anchors, + valid_flags, + img_shape, + allowed_border=0): + img_h, img_w = img_shape[:2] + if allowed_border >= 0: + inside_flags = valid_flags & \ + (flat_anchors[:, 0] >= -allowed_border).type(torch.uint8) & \ + (flat_anchors[:, 1] >= -allowed_border).type(torch.uint8) & \ + (flat_anchors[:, 2] < img_w + allowed_border).type(torch.uint8) & \ + (flat_anchors[:, 3] < img_h + allowed_border).type(torch.uint8) + else: + inside_flags = valid_flags + return inside_flags + + +def unmap(data, count, inds, fill=0): + """ Unmap a subset of item (data) back to the original set of items (of + size count) """ + if data.dim() == 1: + ret = data.new_full((count, ), fill) + ret[inds] = data + else: + new_size = (count, ) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds, :] = data + return ret diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index 
70cc3de9052..85c7beba4fd 100644
--- a/mmdet/core/bbox/__init__.py
+++ b/mmdet/core/bbox/__init__.py
@@ -4,8 +4,9 @@
 from .samplers import (BaseSampler, CombinedSampler,
                        InstanceBalancedPosSampler, IoUBalancedNegSampler,
                        PseudoSampler, RandomSampler, SamplingResult)
+
 from .transforms import (bbox2result, bbox2roi, bbox_flip, bbox_mapping,
-                         bbox_mapping_back, distance2bbox, roi2bbox)
+                         bbox_mapping_back, distance2bbox, roi2bbox, tblr2bboxes)
 
 from .builder import (  # isort:skip, avoid recursive imports
     build_assigner, build_sampler, build_bbox_coder)
@@ -18,4 +19,4 @@
     'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
     'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
     'DeltaXYWHBBoxCoder'
-]
+]
\ No newline at end of file
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
index 4ed1d564318..189cc130b3b 100644
--- a/mmdet/core/bbox/assigners/__init__.py
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -2,10 +2,11 @@
 from .assign_result import AssignResult
 from .atss_assigner import ATSSAssigner
 from .base_assigner import BaseAssigner
+from .effective_area_assigner import EffectiveAreaAssigner
 from .max_iou_assigner import MaxIoUAssigner
 from .point_assigner import PointAssigner
 
 __all__ = [
     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
-    'PointAssigner', 'ATSSAssigner'
+    'PointAssigner', 'ATSSAssigner', 'EffectiveAreaAssigner'
 ]
diff --git a/mmdet/core/bbox/assigners/effective_area_assigner.py b/mmdet/core/bbox/assigners/effective_area_assigner.py
new file mode 100644
index 00000000000..bf2dc22f1ef
--- /dev/null
+++ b/mmdet/core/bbox/assigners/effective_area_assigner.py
@@ -0,0 +1,167 @@
+import torch
+
+from ..geometry import bbox_overlaps, bboxes_area, is_located_in, scale_boxes
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+class EffectiveAreaAssigner(BaseAssigner):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned with `-1`, `0`, or a positive integer
+    indicating the ground truth index.
+
+    - -1: don't care
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        pos_area_thr (float): threshold within which pixels are
+            labelled as positive.
+        neg_area_thr (float): threshold above which pixels are
+            labelled as negative; pixels falling between the two
+            thresholds are ignored.
+        min_pos_iof (float): minimum iof of a pixel with a gt to be
+            labelled as positive.
+        ignore_gt_area_thr (float): threshold within which the pixels
+            are ignored when the gt is labelled as ignored.
+    """
+
+    def __init__(self,
+                 pos_area_thr,
+                 neg_area_thr,
+                 min_pos_iof=1e-2,
+                 ignore_gt_area_thr=0.5):
+        self.pos_area_thr = pos_area_thr
+        self.neg_area_thr = neg_area_thr
+        self.min_pos_iof = min_pos_iof
+        self.ignore_gt_area_thr = ignore_gt_area_thr
+
+    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+        """Assign gt to bboxes.
+
+        This method assigns a gt bbox to every bbox (proposal/anchor); each
+        bbox will be assigned with -1, 0, or a positive number. -1 means
+        don't care, 0 means negative sample, and a positive number is the
+        index (1-based) of the assigned gt.
+
+        Args:
+            bboxes (Tensor): Bounding boxes to be assigned, shape (n, 4).
+            gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
+            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
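+
+        Example (an illustrative check; values assume the default
+            ``min_pos_iof``):
+            >>> self = EffectiveAreaAssigner(0.2, 0.2)
+            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
+            >>> assign_result = self.assign(bboxes, gt_bboxes)
+            >>> expected_gt_inds = torch.LongTensor([1, 0])
+            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)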
+ + Returns: + :obj:`AssignResult`: The assign result. + """ + if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0: + raise ValueError('No gt or bboxes') + bboxes = bboxes[:, :4] + + # constructing effective gt areas + gt_eff = scale_boxes(gt_bboxes, self.pos_area_thr) + # effective bboxes, i.e. center 0.2 part + bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2] + 1) / 2 + is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes) + # the center points lie within the gt boxes + + bbox_and_gt_eff_overlaps = bbox_overlaps(bboxes, gt_eff, mode='iof') + is_bbox_in_gt_eff = is_bbox_in_gt & ( + bbox_and_gt_eff_overlaps > self.min_pos_iof) + # shape (n, k) + # the center point of effective priors should be within the gt box + + # constructing ignored gt areas + gt_ignore = scale_boxes(gt_bboxes, self.neg_area_thr) + is_bbox_in_gt_ignore = ( + bbox_overlaps(bboxes, gt_ignore, mode='iof') > self.min_pos_iof) + is_bbox_in_gt_ignore &= (~is_bbox_in_gt_eff) + # rule out center effective pixels + + gt_areas = bboxes_area(gt_bboxes) + _, sort_idx = gt_areas.sort(descending=True) + # rank all gt bbox areas so that smaller instances + # can overlay larger ones + + assigned_gt_inds = self.assign_one_hot_gt_indices( + is_bbox_in_gt_eff, is_bbox_in_gt_ignore, gt_priority=sort_idx) + + # ignored gts + if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0: + gt_bboxes_ignore = scale_boxes( + gt_bboxes_ignore, scale=self.ignore_gt_area_thr) + is_bbox_in_ignored_gts = is_located_in(bbox_centers, + gt_bboxes_ignore) + is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1) + assigned_gt_inds[is_bbox_in_ignored_gts] = -1 + + num_bboxes, num_gts = (is_bbox_in_gt_eff.size(0), + is_bbox_in_gt_eff.size(1)) + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, )) + pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + + def assign_one_hot_gt_indices(self, + is_bbox_in_gt_eff, + is_bbox_in_gt_ignore, + gt_priority=None): + """Assign only one gt index to each prior box + + (smaller gt has higher priority) + + Args: + is_bbox_in_gt_eff: shape [num_prior, num_gt]. + bool tensor indicating the bbox center is in + the effective area of a gt (e.g. 0-0.2) + is_bbox_in_gt_ignore: shape [num_prior, num_gt]. + bool tensor indicating the bbox + center is in the ignored area of a gt (e.g. 0.2-0.5) + gt_labels: shape [num_gt]. gt labels (0-81 for COCO) + gt_priority: shape [num_gt]. gt priorities. + The gt with a higher priority is more likely to be + assigned to the bbox when the bbox match with multiple gts + + Returns: + :obj:`AssignResult`: The assign result. 
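+
+        (In this implementation the returned value is the raw
+        ``assigned_gt_inds`` tensor of shape [num_prior]; :meth:`assign`
+        wraps it into the final :obj:`AssignResult`.)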
+ """ + num_bboxes, num_gts =\ + is_bbox_in_gt_eff.size(0), is_bbox_in_gt_eff.size(1) + if gt_priority is None: + gt_priority = torch.arange(num_gts).to(is_bbox_in_gt_eff.device) + # the bigger, the more preferable to be assigned + assigned_gt_inds = is_bbox_in_gt_eff.new_full((num_bboxes, ), + 0, + dtype=torch.long) + inds_of_match = torch.any(is_bbox_in_gt_eff, dim=1) + # matched bboxes (to any gt) + inds_of_ignore = torch.any(is_bbox_in_gt_ignore, dim=1) + # ignored indices + assigned_gt_inds[inds_of_ignore] = -1 + if is_bbox_in_gt_eff.sum() == 0: # No gt match + return assigned_gt_inds + + bbox_priority = is_bbox_in_gt_eff.new_full((num_bboxes, num_gts), + -1, + dtype=torch.long) + + # Each bbox could match with multiple gts. + # The following codes deal with this + matched_bbox_and_gt_correspondence = is_bbox_in_gt_eff[inds_of_match] + # shape [nmatch, k] + matched_bbox_gt_inds =\ + torch.nonzero(matched_bbox_and_gt_correspondence)[:, 1] + # the matched gt index of each positive bbox. shape [nmatch] + bbox_priority[is_bbox_in_gt_eff] = gt_priority[matched_bbox_gt_inds] + _, argmax_priority = bbox_priority[inds_of_match].max(dim=1) + # the maximum shape [nmatch] + # effective indices + assigned_gt_inds[inds_of_match] = argmax_priority + 1 + return assigned_gt_inds diff --git a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py index 51b8b19dfbd..e6f684e06d2 100644 --- a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py +++ b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py @@ -102,3 +102,60 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): ious = overlap / (area1[:, None]) return ious + + +def scale_boxes(bboxes, scale): + """Expand an array of boxes by a given scale. + Args: + bboxes (Tensor): shape (m, 4) + scale (float): the scale factor of bboxes + + Returns: + (Tensor): shape (m, 4) scaled bboxes + """ + w_half = (bboxes[:, 2] - bboxes[:, 0] + 1) * .5 + h_half = (bboxes[:, 3] - bboxes[:, 1] + 1) * .5 + x_c = (bboxes[:, 2] + bboxes[:, 0] + 1) * .5 + y_c = (bboxes[:, 3] + bboxes[:, 1] + 1) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = torch.zeros_like(bboxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half - 1 + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half - 1 + return boxes_exp + + +def is_located_in(points, bboxes, is_aligned=False): + """ is center a locates in box b + Then we compute the area of intersect between box_a and box_b. + Args: + points: (tensor) bounding boxes, Shape: [m,2]. + bboxes: (tensor) bounding boxes, Shape: [n,4]. + If is_aligned is ``True``, then m mush be equal to n + Return: + (tensor) intersection area, Shape: [m, n]. 
If is_aligned ``True``, + then shape = [m] + """ + if not is_aligned: + return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \ + (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \ + (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \ + (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0)) + else: + return (points[:, 0] > bboxes[:, 0]) & \ + (points[:, 0] < bboxes[:, 2]) & \ + (points[:, 1] > bboxes[:, 1]) & \ + (points[:, 1] < bboxes[:, 3]) + + +def bboxes_area(bboxes): + """Compute the area of an array of boxes.""" + w = (bboxes[:, 2] - bboxes[:, 0] + 1) + h = (bboxes[:, 3] - bboxes[:, 1] + 1) + areas = w * h + + return areas diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 9ba80f45754..b63554771b7 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -111,3 +111,68 @@ def distance2bbox(points, distance, max_shape=None): x2 = x2.clamp(min=0, max=max_shape[1]) y2 = y2.clamp(min=0, max=max_shape[0]) return torch.stack([x1, y1, x2, y2], -1) + + +def bboxes2tblr(priors, gt, normalizer=1.0): + """Encode ground truth boxes + + Args: + priors (FloatTensor): Prior boxes in point form + Shape: [num_proposals,4]. + gt (FloatTensor): Coords of ground truth for each prior in point-form + Shape: [num_proposals, 4]. + normalizer (float): normalization parameter of encoded boxes + + Return: + encoded boxes (FloatTensor), Shape: [num_proposals, 4] + """ + + # dist b/t match center and prior's center + prior_centers = (priors[:, 0:2] + priors[:, 2:4] + 1) / 2 + wh = priors[:, 2:4] - priors[:, 0:2] + 1 + + xmin, ymin, xmax, ymax = gt.split(1, dim=1) + top = prior_centers[:, 1].unsqueeze(1) - ymin + bottom = ymax - prior_centers[:, 1].unsqueeze(1) + 1 + left = prior_centers[:, 0].unsqueeze(1) - xmin + right = xmax - prior_centers[:, 0].unsqueeze(1) + 1 + loc = torch.cat((top, bottom, left, right), dim=1) + w, h = torch.split(wh, 1, dim=1) + loc[:, :2] /= h + # convert them to the coordinate on the featuremap: 0 -fm_size + loc[:, 2:] /= w + return loc / normalizer + + +def tblr2bboxes(priors, tblr, normalizer=1.0, max_shape=None): + """Decode tblr outputs to prediction boxes + + Args: + priors (FloatTensor): Prior boxes in point form + Shape: [n,4]. + tblr (FloatTensor): Coords of network output in tblr form + Shape: [n, 4]. + normalizer (float): normalization parameter of encoded boxes + max_shape (tuple): Shape of the image. 
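+
+    Example (a round-trip sanity check with the default normalizer,
+        assuming ``bboxes2tblr`` above):
+        >>> priors = torch.Tensor([[0., 0., 10., 10.]])
+        >>> gt = torch.Tensor([[1., 2., 8., 9.]])
+        >>> decoded = tblr2bboxes(priors, bboxes2tblr(priors, gt))
+        >>> assert torch.allclose(decoded, gt)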
+
+    Return:
+        decoded boxes (FloatTensor), Shape: [n, 4]
+    """
+    loc_decode = tblr * normalizer
+    prior_centers = (priors[:, 0:2] + priors[:, 2:4] + 1) / 2
+    wh = priors[:, 2:4] - priors[:, 0:2] + 1
+    w, h = torch.split(wh, 1, dim=1)
+    loc_decode[:, :2] *= h
+    loc_decode[:, 2:] *= w
+    top, bottom, left, right = loc_decode.split(1, dim=1)
+    xmin = prior_centers[:, 0].unsqueeze(1) - left
+    xmax = prior_centers[:, 0].unsqueeze(1) + right - 1
+    ymin = prior_centers[:, 1].unsqueeze(1) - top
+    ymax = prior_centers[:, 1].unsqueeze(1) + bottom - 1
+    boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1)
+    if max_shape is not None:
+        boxes[:, 0].clamp_(min=0, max=max_shape[1] - 1)
+        boxes[:, 1].clamp_(min=0, max=max_shape[0] - 1)
+        boxes[:, 2].clamp_(min=0, max=max_shape[1] - 1)
+        boxes[:, 3].clamp_(min=0, max=max_shape[0] - 1)
+    return boxes
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index 3364c4c9113..2c9b646d006 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -3,6 +3,7 @@
 from .fcos_head import FCOSHead
 from .fovea_head import FoveaHead
 from .free_anchor_retina_head import FreeAnchorRetinaHead
+from .fsaf_head import FSAFHead
 from .ga_retina_head import GARetinaHead
 from .ga_rpn_head import GARPNHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
@@ -16,5 +17,5 @@
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead', 'GARPNHead',
     'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
     'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
-    'ATSSHead'
+    'ATSSHead', 'FSAFHead'
 ]
diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py
new file mode 100644
index 00000000000..25b8bb4f3a5
--- /dev/null
+++ b/mmdet/models/anchor_heads/fsaf_head.py
@@ -0,0 +1,360 @@
+import numpy as np
+import torch
+
+from mmdet.core import (anchor_tblr_target, force_fp32, multi_apply,
+                        multiclass_nms, tblr2bboxes)
+from ..losses import IoULoss
+from ..losses.utils import weight_reduce_loss, weighted_loss
+from ..registry import HEADS, LOSSES
+from .retina_head import RetinaHead
+
+
+@weighted_loss
+def iou_loss_tblr(pred, target, eps=1e-6):
+    """Calculate the IoU loss.
+
+    Computes the IoU loss when both the predictions and the targets are
+    encoded in TBLR format.
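+
+    With per-side distances (t, b, l, r) measured from the same prior
+    center, each box area is (t + b) * (l + r), the intersection is
+    (min(t_p, t_g) + min(b_p, b_g)) * (min(l_p, l_g) + min(r_p, r_g)),
+    and the loss is -log(IoU).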
+ + Args: + pred: shape (num_anchor, 4) + target: shape (num_anchor, 4) + eps: the minimum iou threshold + + Returns: + loss: shape (num_anchor), IoU loss + """ + xt, xb, xl, xr = torch.split(pred, 1, dim=-1) + + # the ground truth position + gt, gb, gl, gr = torch.split(target, 1, dim=-1) + + # compute the bounding box size + X = (xt + xb) * (xl + xr) # AreaX + G = (gt + gb) * (gl + gr) # AreaG + + # compute the IOU + Ih = torch.min(xt, gt) + torch.min(xb, gb) + Iw = torch.min(xl, gl) + torch.min(xr, gr) + + Inter = Ih * Iw + Union = (X + G - Inter).clamp(min=1) + # minimum area should be 1 + + IoU = Inter / Union + IoU = IoU.squeeze() + ious = IoU.clamp(min=eps) + loss = -ious.log() + return loss + + +@LOSSES.register_module +class IoULossTBLR(IoULoss): + + def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0): + super(IoULossTBLR, self).__init__(eps, reduction, loss_weight) + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + if weight is not None and not torch.any(weight > 0): + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + weight = weight.sum(dim=-1) / 4. # iou loss is a scalar! + loss = self.loss_weight * iou_loss_tblr( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@HEADS.register_module +class FSAFHead(RetinaHead): + """ + FSAF anchor-free head used in [1]. + + The head contains two subnetworks. The first classifies anchor boxes and + the second regresses deltas for the anchors (num_anchors is 1 + for anchor-free methods) + + References: + .. [1] https://arxiv.org/pdf/1903.00621.pdf + + Example: + >>> import torch + >>> self = FSAFHead(11, 7) + >>> x = torch.rand(1, 7, 32, 32) + >>> cls_score, bbox_pred = self.forward_single(x) + >>> # Each anchor predicts a score for each class except background + >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors + >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors + >>> assert cls_per_anchor == (self.num_classes - 1) + >>> assert box_per_anchor == 4 + """ + + def __init__(self, + num_classes, + in_channels, + stacked_convs=4, + octave_base_scale=4, + scales_per_octave=3, + conv_cfg=None, + norm_cfg=None, + effective_threshold=0.2, + ignore_threshold=0.2, + target_normalizer=1.0, + **kwargs): + self.effective_threshold = effective_threshold + self.ignore_threshold = ignore_threshold + self.target_normalizer = target_normalizer + super(FSAFHead, self).__init__(num_classes, in_channels, stacked_convs, + octave_base_scale, scales_per_octave, + conv_cfg, norm_cfg, **kwargs) + + def forward_single(self, x): + cls_score, bbox_pred = super(FSAFHead, self).forward_single(x) + return cls_score, self.relu(bbox_pred) + # TBLR encoder only accepts positive bbox_pred + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss( + self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + img_metas, + cfg, + gt_bboxes_ignore=None, + ): + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == len(self.anchor_generators) + batch_size = len(gt_bboxes) + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + + cls_reg_targets = anchor_tblr_target( + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + self.target_normalizer, + cfg, + 
gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + sampling=self.sampling) + if cls_reg_targets is None: + return None + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg, + pos_assigned_gt_inds_list) = cls_reg_targets + + num_gts = np.array(list(map(len, gt_labels))) + num_total_samples = ( + num_total_pos + num_total_neg if self.sampling else num_total_pos) + losses_cls, losses_bbox = multi_apply( + self.loss_single, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_samples=num_total_samples, + cfg=cfg) + cum_num_gts = list(np.cumsum(num_gts)) + for i, assign in enumerate(pos_assigned_gt_inds_list): + for j in range(1, batch_size): + assign[j][assign[j] >= 0] += int(cum_num_gts[j - 1]) + pos_assigned_gt_inds_list[i] = assign.flatten() + labels_list[i] = labels_list[i].flatten() + num_gts = sum(map(len, gt_labels)) + with torch.no_grad(): + loss_levels, = multi_apply( + self.collect_loss_level_single, + losses_cls, + losses_bbox, + pos_assigned_gt_inds_list, + labels_seq=torch.arange(num_gts, device=device)) + loss_levels = torch.stack(loss_levels, dim=0) + loss, argmin = loss_levels.min(dim=0) + losses_cls, losses_bbox, pos_inds = multi_apply( + self.reassign_loss_single, + losses_cls, + losses_bbox, + pos_assigned_gt_inds_list, + labels_list, + list(range(len(losses_cls))), + min_levels=argmin) + + num_pos = torch.cat(pos_inds, 0).sum().float() + acc = self.calculate_accuracy(cls_scores, labels_list, pos_inds) + for i in range(len(losses_cls)): + losses_cls[i] /= num_pos + losses_bbox[i] /= num_pos + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + num_pos=num_pos / batch_size, + accuracy=acc) + + def calculate_accuracy(self, cls_scores, labels_list, pos_inds): + with torch.no_grad(): + num_pos = torch.cat(pos_inds, 0).sum().float() + num_class = cls_scores[0].size(1) + scores = [ + cls.permute(0, 2, 3, 1).reshape(-1, num_class)[pos] + for cls, pos in zip(cls_scores, pos_inds) + ] + labels = [ + l.reshape(-1)[pos] for l, pos in zip(labels_list, pos_inds) + ] + + def argmax(x): + return x.argmax(1) if x.numel() > 0 else -100 + + num_correct = sum([(argmax(score) + 1 == label).sum() + for score, label in zip(scores, labels)]) + return num_correct.float() / (num_pos + 1e-3) + + def collect_loss_level_single(self, cls_loss, reg_loss, + pos_assigned_gt_inds, labels_seq): + """Get the average loss in each FPN level w.r.t. each gt label + + Args: + cls_loss (tensor): classification loss of each feature map pixel, + shape (num_anchor, num_class) + reg_loss (tensor): regression loss of each feature map pixel, + shape (num_anchor) + pos_assigned_gt_inds (tensor): shape (num_anchor), indicating + which gt the prior is assigned to (-1: no assignment) + labels_seq: The rank of labels + + Returns: + + """ + loss = cls_loss.sum(dim=-1) + reg_loss + # total loss at each feature map point + match = ( + pos_assigned_gt_inds.reshape(-1).unsqueeze(1) == + labels_seq.unsqueeze(0)) + loss_ceiling = loss.new_zeros(1).squeeze() + 1e6 + # default loss value for a layer where no anchor is positive + losses_ = torch.stack([ + torch.mean(loss[match[:, i]]) + if match[:, i].sum() > 0 else loss_ceiling for i in labels_seq + ]) + return losses_, + + def reassign_loss_single(self, cls_loss, reg_loss, pos_assigned_gt_inds, + labels, level, min_levels): + """Reassign loss values at each level. 
+ + Reassign loss values at each level by masking those where the + pre-calculated loss is too large + + Args: + cls_loss (tensor): shape (num_anchors, num_classes) + classification loss + reg_loss (tensor): shape (num_anchors) regression loss + pos_assigned_gt_inds (tensor): shape (num_anchors), + the gt indices that each positive anchor corresponds to. + (-1 if it is a negative one) + labels (tensor): shape (num_anchors). Label assigned to each pixel + level (int): the current level index in the + pyramid (0-4 for RetinaNet) + min_levels (tensor): shape (num_gts), + the best-matching level for each gt + + Returns: + cls_loss: shape (num_anchors, num_classes). + Corrected classification loss + reg_loss: shape (num_anchors). Corrected regression loss + keep_indices: shape (num_anchors). Indicating final postive anchors + """ + + unmatch_gt_inds = torch.nonzero(min_levels != level) + # gts indices that unmatch with the current level + match_gt_inds = torch.nonzero(min_levels == level) + loc_weight = cls_loss.new_ones(cls_loss.size(0)) + cls_weight = cls_loss.new_ones(cls_loss.size(0), cls_loss.size(1)) + zeroing_indices = (pos_assigned_gt_inds.view( + -1, 1) == unmatch_gt_inds.view(1, -1)).any(dim=-1) + keep_indices = (pos_assigned_gt_inds.view(-1, 1) == match_gt_inds.view( + 1, -1)).any(dim=-1) + loc_weight[zeroing_indices] = 0 + + # Only the weight corresponding to the label is + # zeroed out if not selected + zeroing_labels = labels[zeroing_indices] - 1 + # the original labels assigned to the anchor box + assert (zeroing_labels >= 0).all() + cls_weight[zeroing_indices, zeroing_labels] = 0 + + # weighted loss for both cls and reg loss + cls_loss = weight_reduce_loss(cls_loss, cls_weight, reduction='sum') + reg_loss = weight_reduce_loss(reg_loss, loc_weight, reduction='sum') + return cls_loss, reg_loss, keep_indices + + def get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=False): + """ + Transform outputs for a single batch item into labeled boxes. + """ + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + mlvl_bboxes = [] + mlvl_scores = [] + for cls_score, bbox_pred, anchors in zip(cls_score_list, + bbox_pred_list, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1) + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + # Get maximum scores for foreground classes. 
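+                # (sigmoid scores have no background column, so every
+                # column is foreground; softmax scores put background at
+                # column 0, which is skipped below)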
+ if self.use_sigmoid_cls: + max_scores, _ = scores.max(dim=1) + else: + max_scores, _ = scores[:, 1:].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + bboxes = tblr2bboxes(anchors, bbox_pred, self.target_normalizer, + img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + mlvl_scores = torch.cat(mlvl_scores) + if self.use_sigmoid_cls: + # Add a dummy background class to the front when using sigmoid + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([padding, mlvl_scores], dim=1) + det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, + cfg.score_thr, cfg.nms, + cfg.max_per_img) + return det_bboxes, det_labels diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 852c9d60035..7874bce0e56 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -5,6 +5,7 @@ from .faster_rcnn import FasterRCNN from .fcos import FCOS from .fovea import FOVEA +from .fsaf import FSAF from .grid_rcnn import GridRCNN from .htc import HybridTaskCascade from .mask_rcnn import MaskRCNN @@ -19,5 +20,5 @@ 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', - 'FOVEA' + 'FOVEA', 'FSAF' ] diff --git a/mmdet/models/detectors/fsaf.py b/mmdet/models/detectors/fsaf.py new file mode 100644 index 00000000000..1f73c349c67 --- /dev/null +++ b/mmdet/models/detectors/fsaf.py @@ -0,0 +1,16 @@ +from ..registry import DETECTORS +from .single_stage import SingleStageDetector + + +@DETECTORS.register_module +class FSAF(SingleStageDetector): + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(FSAF, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) From f1ba30cb030007ba45bcdd3c8a3982e3e267af14 Mon Sep 17 00:00:00 2001 From: wangxinjiang Date: Wed, 22 Apr 2020 16:53:15 +0800 Subject: [PATCH 02/11] changed configs --- configs/fsaf/fsaf_r50_fpn_1x.py | 131 ------------------------ configs/fsaf/fsaf_r50_fpn_1x_coco.py | 56 +++++++++++ configs/fsaf/fsaf_x101_64x4d_fpn_1x.py | 133 ------------------------- 3 files changed, 56 insertions(+), 264 deletions(-) delete mode 100644 configs/fsaf/fsaf_r50_fpn_1x.py create mode 100644 configs/fsaf/fsaf_r50_fpn_1x_coco.py delete mode 100644 configs/fsaf/fsaf_x101_64x4d_fpn_1x.py diff --git a/configs/fsaf/fsaf_r50_fpn_1x.py b/configs/fsaf/fsaf_r50_fpn_1x.py deleted file mode 100644 index 24f2d217bc4..00000000000 --- a/configs/fsaf/fsaf_r50_fpn_1x.py +++ /dev/null @@ -1,131 +0,0 @@ -# model settings -model = dict( - type='FSAF', - pretrained='torchvision://resnet50', - backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs=True, - num_outs=5), - bbox_head=dict( - type='FSAFHead', - num_classes=81, - in_channels=256, - stacked_convs=4, - feat_channels=256, - octave_base_scale=1, - scales_per_octave=1, - anchor_ratios=[1.0], - anchor_strides=[8, 16, 32, 64, 128], - target_normalizer=1.0, - loss_cls=dict( - 
type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0, - reduction='none'), - loss_bbox=dict(type='IoULossTBLR', - eps=1e-6, - loss_weight=1.0, - reduction='none'))) -# training and testing settings -train_cfg = dict( - assigner=dict( - type='EffectiveAreaAssigner', - pos_area_thr=0.2, - neg_area_thr=0.2, - min_pos_iof=0.01), - allowed_border=-1, - pos_weight=-1, - debug=False) -test_cfg = dict( - nms_pre=1000, - min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_thr=0.5), - max_per_img=100) -# dataset settings -dataset_type = 'CocoDataset' -data_root = 'data/coco/' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True), - dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), -] -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict( - type='MultiScaleFlipAug', - img_scale=(1333, 800), - flip=False, - transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']), - ]) -] -data = dict( - imgs_per_gpu=2, - workers_per_gpu=2, - train=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', - pipeline=train_pipeline), - val=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', - pipeline=test_pipeline), - test=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', - pipeline=test_pipeline)) - -optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) -optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) -# learning policy -lr_config = dict( - policy='step', - warmup='linear', - warmup_iters=500, - warmup_ratio=1.0 / 3, - step=[8, 11]) -checkpoint_config = dict(interval=1) -# yapf:disable -log_config = dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook') - ]) -# yapf:enable -# runtime settings -total_epochs = 12 -dist_params = dict(backend='nccl') -log_level = 'INFO' -work_dir = './work_dirs/fsaf_r50_fpn_1x' -load_from = None -resume_from = None -workflow = [('train', 1)] diff --git a/configs/fsaf/fsaf_r50_fpn_1x_coco.py b/configs/fsaf/fsaf_r50_fpn_1x_coco.py new file mode 100644 index 00000000000..637967f9ea9 --- /dev/null +++ b/configs/fsaf/fsaf_r50_fpn_1x_coco.py @@ -0,0 +1,56 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +model = dict( + type='FSAF', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs=True, + num_outs=5), + bbox_head=dict( + type='FSAFHead', + num_classes=81, + in_channels=256, + stacked_convs=4, + feat_channels=256, + octave_base_scale=1, + scales_per_octave=1, + 
anchor_ratios=[1.0], + anchor_strides=[8, 16, 32, 64, 128], + target_normalizer=1.0, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + reduction='none'), + loss_bbox=dict(type='IoULossTBLR', + eps=1e-6, + loss_weight=1.0, + reduction='none'))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='EffectiveAreaAssigner', + pos_area_thr=0.2, + neg_area_thr=0.2, + min_pos_iof=0.01), + allowed_border=-1, + pos_weight=-1, + debug=False) +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) \ No newline at end of file diff --git a/configs/fsaf/fsaf_x101_64x4d_fpn_1x.py b/configs/fsaf/fsaf_x101_64x4d_fpn_1x.py deleted file mode 100644 index 66f2d29799e..00000000000 --- a/configs/fsaf/fsaf_x101_64x4d_fpn_1x.py +++ /dev/null @@ -1,133 +0,0 @@ -# model settings -model = dict( - type='RetinaNet', - pretrained='open-mmlab://resnext101_64x4d', - backbone=dict( - type='ResNeXt', - depth=101, - groups=64, - base_width=4, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs=True, - num_outs=5), - bbox_head=dict( - type='FSAFHead', - num_classes=81, - in_channels=256, - stacked_convs=4, - feat_channels=256, - octave_base_scale=1, - scales_per_octave=1, - anchor_ratios=[1.0], - anchor_strides=[8, 16, 32, 64, 128], - target_normalizer=1.0, - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0, - reduction='none'), - loss_bbox=dict(type='IoULossTBLR', - eps=1e-6, - loss_weight=1.0, - reduction='none'))) -# training and testing settings -train_cfg = dict( - assigner=dict( - type='EffectiveAreaAssigner', - pos_area_thr=0.2, - neg_area_thr=0.2, - min_pos_iof=0.01), - allowed_border=-1, - pos_weight=-1, - debug=False) -test_cfg = dict( - nms_pre=1000, - min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_thr=0.5), - max_per_img=100) -# dataset settings -dataset_type = 'CocoDataset' -data_root = 'data/coco/' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True), - dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), -] -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict( - type='MultiScaleFlipAug', - img_scale=(1333, 800), - flip=False, - transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']), - ]) -] -data = dict( - imgs_per_gpu=2, - workers_per_gpu=2, - train=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', - pipeline=train_pipeline), - val=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', - pipeline=test_pipeline), - test=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 
'val2017/', - pipeline=test_pipeline)) - -optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) -optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) -# learning policy -lr_config = dict( - policy='step', - warmup='linear', - warmup_iters=500, - warmup_ratio=1.0 / 3, - step=[8, 11]) -checkpoint_config = dict(interval=1) -# yapf:disable -log_config = dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook') - ]) -# yapf:enable -# runtime settings -total_epochs = 12 -dist_params = dict(backend='nccl') -log_level = 'INFO' -work_dir = './work_dirs/fsaf_x101_64x4d_fpn_1x' -load_from = None -resume_from = None -workflow = [('train', 1)] From 32d49a16b7ccd758cb512d6119650e93a3725bfa Mon Sep 17 00:00:00 2001 From: wangxinjiang Date: Wed, 22 Apr 2020 18:08:24 +0800 Subject: [PATCH 03/11] changed bbox_coder --- mmdet/core/bbox/__init__.py | 5 +- mmdet/core/bbox/coder/tblr_bbox_coder.py | 175 ++++++++++++++++++ .../anchor_heads}/anchor_tblr_target.py | 6 +- mmdet/models/anchor_heads/fsaf_head.py | 147 ++++++++++++++- 4 files changed, 322 insertions(+), 11 deletions(-) create mode 100644 mmdet/core/bbox/coder/tblr_bbox_coder.py rename mmdet/{core/anchor => models/anchor_heads}/anchor_tblr_target.py (97%) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index 85c7beba4fd..6452e4fe2df 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -6,7 +6,8 @@ PseudoSampler, RandomSampler, SamplingResult) from .transforms import (bbox2result, bbox2roi, bbox_flip, bbox_mapping, - bbox_mapping_back, distance2bbox, roi2bbox, tblr2bboxes) + bbox_mapping_back, distance2bbox, roi2bbox, + tblr2bboxes, bboxes2tblr) from .builder import ( # isort:skip, avoid recursive imports build_assigner, build_sampler, build_bbox_coder) @@ -18,5 +19,5 @@ 'SamplingResult', 'build_assigner', 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', - 'DeltaXYWHBBoxCoder' + 'DeltaXYWHBBoxCoder', 'tblr2bboxes', 'bboxes2tblr' ] \ No newline at end of file diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py new file mode 100644 index 00000000000..d906adca610 --- /dev/null +++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py @@ -0,0 +1,175 @@ +import numpy as np +import torch + +from ..registry import BBOX_CODERS +from .base_bbox_coder import BaseBBoxCoder + + +@BBOX_CODERS.register_module +class DeltaXYWHBBoxCoder(BaseBBoxCoder): + """Delta XYWH BBox coder + + Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2, + y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh) + back to original bbox (x1, y1, x2, y2). + + References: + .. 
[1] https://arxiv.org/abs/1311.2524 + + Args: + target_means (Sequence[float]): denormalizing means of target for + delta coordinates + target_stds (Sequence[float]): denormalizing standard deviation of + target for delta coordinates + """ + + def __init__(self, + target_means=(0., 0., 0., 0.), + target_stds=(1., 1., 1., 1.)): + super(BaseBBoxCoder, self).__init__() + self.means = target_means + self.stds = target_stds + + def encode(self, bboxes, gt_bboxes): + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds) + return encoded_bboxes + + def decode(self, + bboxes, + pred_bboxes, + max_shape=None, + wh_ratio_clip=16 / 1000): + assert pred_bboxes.size(0) == bboxes.size(0) + decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds, + max_shape, wh_ratio_clip) + + return decoded_bboxes + + +def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)): + """Compute deltas of proposals w.r.t. gt. + + We usually compute the deltas of x, y, w, h of proposals w.r.t ground + truth bboxes to get regression target. + This is the inverse function of `delta2bbox()` + + Args: + proposals (Tensor): Boxes to be transformed, shape (N, ..., 4) + gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4) + means (Sequence[float]): Denormalizing means for delta coordinates + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates + + Returns: + Tensor: deltas with shape (N, 4), where columns represent dx, dy, + dw, dh. + + """ + assert proposals.size() == gt.size() + + proposals = proposals.float() + gt = gt.float() + px = (proposals[..., 0] + proposals[..., 2]) * 0.5 + py = (proposals[..., 1] + proposals[..., 3]) * 0.5 + pw = proposals[..., 2] - proposals[..., 0] + ph = proposals[..., 3] - proposals[..., 1] + + gx = (gt[..., 0] + gt[..., 2]) * 0.5 + gy = (gt[..., 1] + gt[..., 3]) * 0.5 + gw = gt[..., 2] - gt[..., 0] + gh = gt[..., 3] - gt[..., 1] + + dx = (gx - px) / pw + dy = (gy - py) / ph + dw = torch.log(gw / pw) + dh = torch.log(gh / ph) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + means = deltas.new_tensor(means).unsqueeze(0) + stds = deltas.new_tensor(stds).unsqueeze(0) + deltas = deltas.sub_(means).div_(stds) + + return deltas + + +def delta2bbox(rois, + deltas, + means=(0., 0., 0., 0.), + stds=(1., 1., 1., 1.), + max_shape=None, + wh_ratio_clip=16 / 1000): + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + This is the inverse function of `bbox2delta()` + + Args: + rois (Tensor): Boxes to be transformed. Has shape (N, 4) + deltas (Tensor): Encoded offsets with respect to each roi. + Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when + rois is a grid of anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates + max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) + wh_ratio_clip (float): Maximum aspect ratio for boxes. + + Returns: + Tensor: Boxes with shape (N, 4), where columns represent + tl_x, tl_y, br_x, br_y. + + References: + .. 
[1] https://arxiv.org/abs/1311.2524 + + Example: + >>> rois = torch.Tensor([[ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 5., 5., 5., 5.]]) + >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], + >>> [ 1., 1., 1., 1.], + >>> [ 0., 0., 2., -1.], + >>> [ 0.7, -1.9, -0.5, 0.3]]) + >>> delta2bbox(rois, deltas, max_shape=(32, 32)) + tensor([[0.0000, 0.0000, 1.0000, 1.0000], + [0.1409, 0.1409, 2.8591, 2.8591], + [0.0000, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000]]) + """ + means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) + stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) + denorm_deltas = deltas * stds + means + dx = denorm_deltas[:, 0::4] + dy = denorm_deltas[:, 1::4] + dw = denorm_deltas[:, 2::4] + dh = denorm_deltas[:, 3::4] + max_ratio = np.abs(np.log(wh_ratio_clip)) + dw = dw.clamp(min=-max_ratio, max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Compute center of each roi + px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) + py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) + # Compute width/height of each roi + pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw) + ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() + # Use network energy to shift the center of each roi + gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx + gy = torch.addcmul(py, 1, ph, dy) # gy = py + ph * dy + # Convert center-xy/width/height to top-left, bottom-right + x1 = gx - gw * 0.5 + y1 = gy - gh * 0.5 + x2 = gx + gw * 0.5 + y2 = gy + gh * 0.5 + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) + return bboxes diff --git a/mmdet/core/anchor/anchor_tblr_target.py b/mmdet/models/anchor_heads/anchor_tblr_target.py similarity index 97% rename from mmdet/core/anchor/anchor_tblr_target.py rename to mmdet/models/anchor_heads/anchor_tblr_target.py index 51670ac2293..f12acd8fa65 100644 --- a/mmdet/core/anchor/anchor_tblr_target.py +++ b/mmdet/models/anchor_heads/anchor_tblr_target.py @@ -1,8 +1,8 @@ import torch -from ..bbox import (PseudoSampler, assign_and_sample, bboxes2tblr, - build_assigner) -from ..utils import multi_apply +from mmdet.core.bbox import (PseudoSampler, assign_and_sample, bboxes2tblr, + build_assigner) +from mmdet.core.utils import multi_apply def anchor_tblr_target(anchor_list, diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py index 25b8bb4f3a5..2271499e2c5 100644 --- a/mmdet/models/anchor_heads/fsaf_head.py +++ b/mmdet/models/anchor_heads/fsaf_head.py @@ -1,8 +1,10 @@ import numpy as np import torch +from mmdet.core import (anchor_inside_flags, build_anchor_generator, + build_assigner, build_bbox_coder, build_sampler, + force_fp32, images_to_levels, multi_apply, + multiclass_nms, unmap) -from mmdet.core import (anchor_tblr_target, force_fp32, multi_apply, - multiclass_nms, tblr2bboxes) from ..losses import IoULoss from ..losses.utils import weight_reduce_loss, weighted_loss from ..registry import HEADS, LOSSES @@ -106,8 +108,6 @@ def __init__(self, num_classes, in_channels, stacked_convs=4, - octave_base_scale=4, - scales_per_octave=3, conv_cfg=None, norm_cfg=None, effective_threshold=0.2, @@ -118,7 +118,6 @@ def __init__(self, 
self.ignore_threshold = ignore_threshold self.target_normalizer = target_normalizer super(FSAFHead, self).__init__(num_classes, in_channels, stacked_convs, - octave_base_scale, scales_per_octave, conv_cfg, norm_cfg, **kwargs) def forward_single(self, x): @@ -126,6 +125,142 @@ return cls_score, self.relu(bbox_pred) # TBLR encoder only accepts positive bbox_pred + def _get_targets_single(self, + flat_anchors, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=1, + unmap_outputs=True): + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags.type(torch.bool), :] + + assign_result = self.assigner.assign( + anchors, gt_bboxes, gt_bboxes_ignore, + None if self.sampling else gt_labels) + sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.background_label, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + pos_gt_inds = anchors.new_full((num_valid_anchors,), + -1, + dtype=torch.long) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + else: + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + pos_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds + if gt_labels is None: + # only rpn gives gt_labels as None, this time FG is 1 + labels[pos_inds] = 1 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap(labels, num_total_anchors, inside_flags) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + pos_gt_inds = unmap( + pos_gt_inds, num_total_anchors, inside_flags, fill=-1) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds, pos_gt_inds) + + def get_targets(self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True): + """Compute regression and classification targets for anchors in + multiple images.
+ + """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors to a single tensor + concat_anchor_list = [] + concat_valid_flag_list = [] + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + concat_anchor_list.append(torch.cat(anchor_list[i])) + concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, + pos_inds_list, neg_inds_list, pos_assigned_gt_inds) = multi_apply( + self._get_targets_single, + concat_anchor_list, + concat_valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # split targets to a list w.r.t. multiple levels + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + pos_assigned_gt_inds_list = images_to_levels(pos_assigned_gt_inds, + num_level_anchors) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg, + pos_assigned_gt_inds_list) + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def loss( self, @@ -144,7 +279,7 @@ def loss( anchor_list, valid_flag_list = self.get_anchors( featmap_sizes, img_metas, device=device) - cls_reg_targets = anchor_tblr_target( + cls_reg_targets = self.get_targets( anchor_list, valid_flag_list, gt_bboxes, From 229088165e7f1e9684dd20f05baf10ed65acb350 Mon Sep 17 00:00:00 2001 From: wangxinjiang Date: Thu, 23 Apr 2020 16:42:02 +0800 Subject: [PATCH 04/11] get_targets fun in AnchorHead class is modified to enable additional returns from _get_target_single --- mmdet/models/anchor_heads/anchor_head.py | 33 ++++++++++++++++-------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py index c5a1b4bd92f..668da6378c1 100644 --- a/mmdet/models/anchor_heads/anchor_head.py +++ b/mmdet/models/anchor_heads/anchor_head.py @@ -281,6 +281,11 @@ def get_targets(self, bbox_weights_list (list[Tensor]): BBox weights of each level num_total_pos (int): Number of positive samples in all images num_total_neg (int): Number of negative samples in all images + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). 
+ The results will be concatenated after the end + """ num_imgs = len(img_metas) assert len(anchor_list) == len(valid_flag_list) == num_imgs @@ -300,17 +305,19 @@ def get_targets(self, gt_bboxes_ignore_list = [None for _ in range(num_imgs)] if gt_labels_list is None: gt_labels_list = [None for _ in range(num_imgs)] + results = multi_apply( + self._get_targets_single, + concat_anchor_list, + concat_valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, - pos_inds_list, neg_inds_list) = multi_apply( - self._get_targets_single, - concat_anchor_list, - concat_valid_flag_list, - gt_bboxes_list, - gt_bboxes_ignore_list, - gt_labels_list, - img_metas, - label_channels=label_channels, - unmap_outputs=unmap_outputs) + pos_inds_list, neg_inds_list) = results[:6] + rest_results = list(results[6:]) # user-added return values # no valid anchors if any([labels is None for labels in all_labels]): return None @@ -325,8 +332,12 @@ def get_targets(self, num_level_anchors) bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors) + for i, r in enumerate(rest_results): # user-added return values + rest_results[i] = images_to_levels(r, num_level_anchors) + return (labels_list, label_weights_list, bbox_targets_list, - bbox_weights_list, num_total_pos, num_total_neg) + bbox_weights_list, num_total_pos, num_total_neg) \ + + tuple(rest_results) def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, bbox_targets, bbox_weights, num_total_samples): From 544c1691f3def809120db77030f5c68bdf67c46b Mon Sep 17 00:00:00 2001 From: wangxinjiang Date: Thu, 23 Apr 2020 16:44:16 +0800 Subject: [PATCH 05/11] effective area assigner and tblr bbox_coder are implemented --- mmdet/core/bbox/__init__.py | 11 +- .../bbox/assigners/effective_area_assigner.py | 112 +++++++-- mmdet/core/bbox/coder/__init__.py | 4 +- mmdet/core/bbox/coder/tblr_bbox_coder.py | 213 +++++++----------- .../bbox/iou_calculators/iou2d_calculator.py | 56 ----- 5 files changed, 175 insertions(+), 221 deletions(-) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index 6452e4fe2df..ccb506db8e2 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -1,17 +1,16 @@ from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner -from .coder import BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder +from .builder import ( # isort:skip, avoid recursive imports + build_assigner, build_sampler, build_bbox_coder) +from .coder import BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder, \ + TBLRBBoxCoder from .iou_calculators import BboxOverlaps2D, bbox_overlaps from .samplers import (BaseSampler, CombinedSampler, InstanceBalancedPosSampler, IoUBalancedNegSampler, PseudoSampler, RandomSampler, SamplingResult) - from .transforms import (bbox2result, bbox2roi, bbox_flip, bbox_mapping, bbox_mapping_back, distance2bbox, roi2bbox, tblr2bboxes, bboxes2tblr) -from .builder import ( # isort:skip, avoid recursive imports - build_assigner, build_sampler, build_bbox_coder) - __all__ = [ 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner', 'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler', @@ -19,5 +18,5 @@ 'SamplingResult', 'build_assigner', 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 
'PseudoBBoxCoder', - 'DeltaXYWHBBoxCoder', 'tblr2bboxes', 'bboxes2tblr' + 'DeltaXYWHBBoxCoder', 'tblr2bboxes', 'bboxes2tblr', 'TBLRBBoxCoder' ] \ No newline at end of file diff --git a/mmdet/core/bbox/assigners/effective_area_assigner.py b/mmdet/core/bbox/assigners/effective_area_assigner.py index bf2dc22f1ef..7ee495f5608 100644 --- a/mmdet/core/bbox/assigners/effective_area_assigner.py +++ b/mmdet/core/bbox/assigners/effective_area_assigner.py @@ -1,10 +1,69 @@ import torch -from ..geometry import bbox_overlaps, bboxes_area, is_located_in, scale_boxes from .assign_result import AssignResult from .base_assigner import BaseAssigner +from ..iou_calculators import build_iou_calculator +from ..registry import BBOX_ASSIGNERS +def scale_boxes(bboxes, scale): + """Expand an array of boxes by a given scale. + Args: + bboxes (Tensor): shape (m, 4) + scale (float): the scale factor of bboxes + + Returns: + (Tensor): shape (m, 4) scaled bboxes + """ + w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5 + h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5 + x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5 + y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = torch.zeros_like(bboxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + return boxes_exp + + +def is_located_in(points, bboxes, is_aligned=False): + """Check whether each point falls strictly inside a box. + Args: + points: (tensor) point coordinates, Shape: [m,2]. + bboxes: (tensor) bounding boxes, Shape: [n,4]. + If is_aligned is ``True``, then m must be equal to n + Return: + (tensor) bool mask, Shape: [m, n]. If is_aligned is ``True``, + then shape = [m] + """ + if not is_aligned: + return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \ + (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \ + (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \ + (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0)) + else: + return (points[:, 0] > bboxes[:, 0]) & \ + (points[:, 0] < bboxes[:, 2]) & \ + (points[:, 1] > bboxes[:, 1]) & \ + (points[:, 1] < bboxes[:, 3]) + + +def bboxes_area(bboxes): + """Compute the area of an array of boxes.""" + w = (bboxes[:, 2] - bboxes[:, 0]) + h = (bboxes[:, 3] - bboxes[:, 1]) + areas = w * h + + return areas + + +@BBOX_ASSIGNERS.register_module class EffectiveAreaAssigner(BaseAssigner): """Assign a corresponding gt bbox or background to each bbox. @@ -30,11 +89,13 @@ def __init__(self, pos_area_thr, neg_area_thr, min_pos_iof=1e-2, - ignore_gt_area_thr=0.5): + ignore_gt_area_thr=0.5, + iou_calculator=dict(type='BboxOverlaps2D')): self.pos_area_thr = pos_area_thr self.neg_area_thr = neg_area_thr self.min_pos_iof = min_pos_iof self.ignore_gt_area_thr = ignore_gt_area_thr + self.iou_calculator = build_iou_calculator(iou_calculator) def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): """Assign gt to bboxes. @@ -49,7 +110,7 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are labelled as `ignored`, e.g., crowd boxes in COCO. - gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + gt_labels (Tensor, optional): Label of gt_bboxes, shape (num_gt, ). Returns: :obj:`AssignResult`: The assign result.
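Since `scale_boxes` drives both the effective and the ignore regions, a quick numeric check is helpful; a toy example using the `scale_boxes` helper defined above (numbers are made up):
```python
import torch

gt = torch.tensor([[0., 0., 100., 100.]])
print(scale_boxes(gt, 0.2))  # tensor([[40., 40., 60., 60.]]) -> effective area
print(scale_boxes(gt, 0.5))  # tensor([[25., 25., 75., 75.]]) -> outer bound of
                             # the 0.2-0.5 ignore band described in the docstring
```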
@@ -61,11 +122,14 @@ # constructing effective gt areas gt_eff = scale_boxes(gt_bboxes, self.pos_area_thr) # effective bboxes, i.e. center 0.2 part - bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2] + 1) / 2 + bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2 is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes) # the center points lie within the gt boxes - bbox_and_gt_eff_overlaps = bbox_overlaps(bboxes, gt_eff, mode='iof') + # Only calculate bbox and gt_eff IoF. This enables small prior bboxes + # to match large gts + bbox_and_gt_eff_overlaps = self.iou_calculator(bboxes, gt_eff, + mode='iof') is_bbox_in_gt_eff = is_bbox_in_gt & ( bbox_and_gt_eff_overlaps > self.min_pos_iof) # shape (n, k) @@ -73,8 +137,9 @@ # constructing ignored gt areas gt_ignore = scale_boxes(gt_bboxes, self.neg_area_thr) - is_bbox_in_gt_ignore = ( - bbox_overlaps(bboxes, gt_ignore, mode='iof') > self.min_pos_iof) + is_bbox_in_gt_ignore = (self.iou_calculator(bboxes, + gt_ignore, mode='iof') + > self.min_pos_iof) is_bbox_in_gt_ignore &= (~is_bbox_in_gt_eff) # rule out center effective pixels @@ -95,8 +160,7 @@ is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1) assigned_gt_inds[is_bbox_in_ignored_gts] = -1 - num_bboxes, num_gts = (is_bbox_in_gt_eff.size(0), - is_bbox_in_gt_eff.size(1)) + num_bboxes, num_gts = is_bbox_in_gt_eff.shape if gt_labels is not None: assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, )) pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze() @@ -114,7 +178,6 @@ def assign_one_hot_gt_indices(self, is_bbox_in_gt_ignore, gt_priority=None): """Assign only one gt index to each prior box - (smaller gt has higher priority) Args: @@ -124,7 +187,7 @@ def assign_one_hot_gt_indices(self, is_bbox_in_gt_ignore: shape [num_prior, num_gt]. bool tensor indicating the bbox center is in the ignored area of a gt (e.g. 0.2-0.5) - gt_labels: shape [num_gt]. gt labels (0-81 for COCO) + gt_labels: shape [num_gt]. gt labels (0-80 for COCO) gt_priority: shape [num_gt]. gt priorities. The gt with a higher priority is more likely to be assigned to the bbox when the bbox matches multiple gts Returns: :obj:`AssignResult`: The assign result. """ - num_bboxes, num_gts =\ - is_bbox_in_gt_eff.size(0), is_bbox_in_gt_eff.size(1) + num_bboxes, num_gts = is_bbox_in_gt_eff.shape + if gt_priority is None: gt_priority = torch.arange(num_gts).to(is_bbox_in_gt_eff.device) # the bigger, the more preferable to be assigned + # the assigned inds are by default 0 (background) assigned_gt_inds = is_bbox_in_gt_eff.new_full((num_bboxes, ), 0, dtype=torch.long) @@ -148,20 +212,26 @@ def assign_one_hot_gt_indices(self, if is_bbox_in_gt_eff.sum() == 0: # No gt match return assigned_gt_inds - bbox_priority = is_bbox_in_gt_eff.new_full((num_bboxes, num_gts), + # The priority of each prior box and gt pair. If one prior box is + # matched to multiple gts, only the pair with the highest priority + # is saved + pair_priority = is_bbox_in_gt_eff.new_full((num_bboxes, num_gts), -1, dtype=torch.long) # Each bbox could match with multiple gts.
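The hunk below resolves anchors that fall into several effective areas by filling a per-pair priority table and taking a row-wise argmax; the same trick in miniature, as a standalone sketch with toy tensors:
```python
import torch

# anchor 0 sits in the effective areas of gts 0 and 2; anchor 1 in gt 1's
is_in_eff = torch.tensor([[True, False, True],
                          [False, True, False]])
gt_priority = torch.arange(3)  # the bigger, the more preferable
pair_priority = torch.full((2, 3), -1, dtype=torch.long)
matched_gt = torch.nonzero(is_in_eff)[:, 1]
pair_priority[is_in_eff] = gt_priority[matched_gt]
_, best = pair_priority.max(dim=1)
print(best + 1)  # tensor([3, 2]): 1-based gt index kept for each anchor
```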
- # The following codes deal with this + # The following code deals with this situation + + # Whether a bbox matches a gt, bool tensor, shape [num_match, num_gt] matched_bbox_and_gt_correspondence = is_bbox_in_gt_eff[inds_of_match] - # shape [nmatch, k] + # The matched gt index of each positive bbox. Length >= num_match, + # since one bbox could match multiple gts matched_bbox_gt_inds =\ torch.nonzero(matched_bbox_and_gt_correspondence)[:, 1] - # the matched gt index of each positive bbox. shape [nmatch] - bbox_priority[is_bbox_in_gt_eff] = gt_priority[matched_bbox_gt_inds] - _, argmax_priority = bbox_priority[inds_of_match].max(dim=1) - # the maximum shape [nmatch] + # Assign priority to each bbox-gt pair. + pair_priority[is_bbox_in_gt_eff] = gt_priority[matched_bbox_gt_inds] + _, argmax_priority = pair_priority[inds_of_match].max(dim=1) + # the maximum shape [num_match] # effective indices - assigned_gt_inds[inds_of_match] = argmax_priority + 1 + assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based return assigned_gt_inds diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py index cc8b969eeb0..8a36eabebaa 100644 --- a/mmdet/core/bbox/coder/__init__.py +++ b/mmdet/core/bbox/coder/__init__.py @@ -1,5 +1,7 @@ from .base_bbox_coder import BaseBBoxCoder from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder from .pseudo_bbox_coder import PseudoBBoxCoder +from .tblr_bbox_coder import TBLRBBoxCoder -__all__ = ['BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder'] +__all__ = ['BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', + 'TBLRBBoxCoder'] diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py index d906adca610..700992bb647 100644 --- a/mmdet/core/bbox/coder/tblr_bbox_coder.py +++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py @@ -1,175 +1,114 @@ -import numpy as np import torch -from ..registry import BBOX_CODERS from .base_bbox_coder import BaseBBoxCoder +from ..registry import BBOX_CODERS @BBOX_CODERS.register_module -class DeltaXYWHBBoxCoder(BaseBBoxCoder): - """Delta XYWH BBox coder +class TBLRBBoxCoder(BaseBBoxCoder): + """TBLR BBox coder - Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2, - y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh) - back to original bbox (x1, y1, x2, y2). + Following the practice in FSAF [1]_, this coder encodes bbox (x1, y1, x2, + y2) into (top, bottom, left, right) and decodes it back to the original box. References: - .. [1] https://arxiv.org/abs/1311.2524 + ..
[1] https://arxiv.org/abs/1903.00621 Args: - target_means (Sequence[float]): denormalizing means of target for - delta coordinates - target_stds (Sequence[float]): denormalizing standard deviation of - target for delta coordinates + normalizer (Sequence[float] | float): normalization factor of the + encoded (top, bottom, left, right) coordinates """ def __init__(self, - target_means=(0., 0., 0., 0.), - target_stds=(1., 1., 1., 1.)): + normalizer=1.0): super(BaseBBoxCoder, self).__init__() - self.means = target_means - self.stds = target_stds + self.normalizer = normalizer def encode(self, bboxes, gt_bboxes): assert bboxes.size(0) == gt_bboxes.size(0) assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 - encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds) + encoded_bboxes = bboxes2tblr(bboxes, gt_bboxes, self.normalizer) return encoded_bboxes def decode(self, bboxes, pred_bboxes, - max_shape=None, - wh_ratio_clip=16 / 1000): + max_shape=None): assert pred_bboxes.size(0) == bboxes.size(0) - decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds, - max_shape, wh_ratio_clip) + decoded_bboxes = tblr2bboxes(bboxes, pred_bboxes, self.normalizer, + max_shape) return decoded_bboxes -def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)): - """Compute deltas of proposals w.r.t. gt. - - We usually compute the deltas of x, y, w, h of proposals w.r.t ground - truth bboxes to get regression target. - This is the inverse function of `delta2bbox()` +def bboxes2tblr(priors, gt, normalizer=1.0): + """Encode ground truth boxes Args: - proposals (Tensor): Boxes to be transformed, shape (N, ..., 4) - gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4) - means (Sequence[float]): Denormalizing means for delta coordinates - stds (Sequence[float]): Denormalizing standard deviation for delta - coordinates - - Returns: - Tensor: deltas with shape (N, 4), where columns represent dx, dy, - dw, dh. - + priors (FloatTensor): Prior boxes in point form + Shape: [num_proposals,4]. + gt (FloatTensor): Coords of ground truth for each prior in point-form + Shape: [num_proposals, 4]. + normalizer (list | float): normalization parameter of + encoded boxes. If it is a list, it has to have length = 4 + + Return: + encoded boxes (FloatTensor), Shape: [num_proposals, 4] """ - assert proposals.size() == gt.size() - - proposals = proposals.float() - gt = gt.float() - px = (proposals[..., 0] + proposals[..., 2]) * 0.5 - py = (proposals[..., 1] + proposals[..., 3]) * 0.5 - pw = proposals[..., 2] - proposals[..., 0] - ph = proposals[..., 3] - proposals[..., 1] - - gx = (gt[..., 0] + gt[..., 2]) * 0.5 - gy = (gt[..., 1] + gt[..., 3]) * 0.5 - gw = gt[..., 2] - gt[..., 0] - gh = gt[..., 3] - gt[..., 1] - dx = (gx - px) / pw - dy = (gy - py) / ph - dw = torch.log(gw / pw) - dh = torch.log(gh / ph) - deltas = torch.stack([dx, dy, dw, dh], dim=-1) - - means = deltas.new_tensor(means).unsqueeze(0) - stds = deltas.new_tensor(stds).unsqueeze(0) - deltas = deltas.sub_(means).div_(stds) - - return deltas - - -def delta2bbox(rois, - deltas, - means=(0., 0., 0., 0.), - stds=(1., 1., 1., 1.), - max_shape=None, - wh_ratio_clip=16 / 1000): - """Apply deltas to shift/scale base boxes. - - Typically the rois are anchor or proposed bounding boxes and the deltas are - network outputs used to shift/scale those boxes.
- This is the inverse function of `bbox2delta()` + # dist b/t match center and prior's center + if not isinstance(normalizer, float): + normalizer = torch.tensor(normalizer).to(priors.device) + assert len(normalizer) == 4, 'Normalizer must have length = 4' + prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2 + wh = priors[:, 2:4] - priors[:, 0:2] + + xmin, ymin, xmax, ymax = gt.split(1, dim=1) + top = prior_centers[:, 1].unsqueeze(1) - ymin + bottom = ymax - prior_centers[:, 1].unsqueeze(1) + left = prior_centers[:, 0].unsqueeze(1) - xmin + right = xmax - prior_centers[:, 0].unsqueeze(1) + loc = torch.cat((top, bottom, left, right), dim=1) + w, h = torch.split(wh, 1, dim=1) + loc[:, :2] /= h + # convert them to the coordinate on the featuremap: 0 -fm_size + loc[:, 2:] /= w + return loc / normalizer + + +def tblr2bboxes(priors, tblr, normalizer=1.0, max_shape=None): + """Decode tblr outputs to prediction boxes Args: - rois (Tensor): Boxes to be transformed. Has shape (N, 4) - deltas (Tensor): Encoded offsets with respect to each roi. - Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when - rois is a grid of anchors. Offset encoding follows [1]_. - means (Sequence[float]): Denormalizing means for delta coordinates - stds (Sequence[float]): Denormalizing standard deviation for delta - coordinates - max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) - wh_ratio_clip (float): Maximum aspect ratio for boxes. - - Returns: - Tensor: Boxes with shape (N, 4), where columns represent - tl_x, tl_y, br_x, br_y. - - References: - .. [1] https://arxiv.org/abs/1311.2524 - - Example: - >>> rois = torch.Tensor([[ 0., 0., 1., 1.], - >>> [ 0., 0., 1., 1.], - >>> [ 0., 0., 1., 1.], - >>> [ 5., 5., 5., 5.]]) - >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], - >>> [ 1., 1., 1., 1.], - >>> [ 0., 0., 2., -1.], - >>> [ 0.7, -1.9, -0.5, 0.3]]) - >>> delta2bbox(rois, deltas, max_shape=(32, 32)) - tensor([[0.0000, 0.0000, 1.0000, 1.0000], - [0.1409, 0.1409, 2.8591, 2.8591], - [0.0000, 0.3161, 4.1945, 0.6839], - [5.0000, 5.0000, 5.0000, 5.0000]]) + priors (FloatTensor): Prior boxes in point form + Shape: [n,4]. + tblr (FloatTensor): Coords of network output in tblr form + Shape: [n, 4]. + normalizer (list | float): normalization parameter of encoded boxes + max_shape (tuple): Shape of the image. 
+ + Return: + decoded boxes (FloatTensor), Shape: [n, 4] + """ - means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) - stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) - denorm_deltas = deltas * stds + means - dx = denorm_deltas[:, 0::4] - dy = denorm_deltas[:, 1::4] - dw = denorm_deltas[:, 2::4] - dh = denorm_deltas[:, 3::4] - max_ratio = np.abs(np.log(wh_ratio_clip)) - dw = dw.clamp(min=-max_ratio, max=max_ratio) - dh = dh.clamp(min=-max_ratio, max=max_ratio) - # Compute center of each roi - px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) - py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) - # Compute width/height of each roi - pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw) - ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh) - # Use exp(network energy) to enlarge/shrink each roi - gw = pw * dw.exp() - gh = ph * dh.exp() - # Use network energy to shift the center of each roi - gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx - gy = torch.addcmul(py, 1, ph, dy) # gy = py + ph * dy - # Convert center-xy/width/height to top-left, bottom-right - x1 = gx - gw * 0.5 - y1 = gy - gh * 0.5 - x2 = gx + gw * 0.5 - y2 = gy + gh * 0.5 + if not isinstance(normalizer, float): + normalizer = torch.tensor(normalizer).to(priors.device) + assert len(normalizer) == 4, 'Normalizer must have length = 4' + + loc_decode = tblr * normalizer + prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2 + wh = priors[:, 2:4] - priors[:, 0:2] + w, h = torch.split(wh, 1, dim=1) + loc_decode[:, :2] *= h + loc_decode[:, 2:] *= w + top, bottom, left, right = loc_decode.split(1, dim=1) + xmin = prior_centers[:, 0].unsqueeze(1) - left + xmax = prior_centers[:, 0].unsqueeze(1) + right + ymin = prior_centers[:, 1].unsqueeze(1) - top + ymax = prior_centers[:, 1].unsqueeze(1) + bottom + boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1) if max_shape is not None: - x1 = x1.clamp(min=0, max=max_shape[1]) - y1 = y1.clamp(min=0, max=max_shape[0]) - x2 = x2.clamp(min=0, max=max_shape[1]) - y2 = y2.clamp(min=0, max=max_shape[0]) - bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) - return bboxes + boxes[:, 0].clamp_(min=0, max=max_shape[1] - 1) + boxes[:, 1].clamp_(min=0, max=max_shape[0] - 1) + boxes[:, 2].clamp_(min=0, max=max_shape[1] - 1) + boxes[:, 3].clamp_(min=0, max=max_shape[0] - 1) + return boxes diff --git a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py index e6f684e06d2..3b788f9ccb2 100644 --- a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py +++ b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py @@ -103,59 +103,3 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): return ious - -def scale_boxes(bboxes, scale): - """Expand an array of boxes by a given scale.
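A round-trip sanity check for the encode/decode pair above (a sketch, assuming `bboxes2tblr` and `tblr2bboxes` are imported from this module):
```python
import torch

priors = torch.tensor([[10., 10., 20., 20.]])  # one 10x10 prior box
gts = torch.tensor([[8., 12., 24., 26.]])
loc = bboxes2tblr(priors, gts, normalizer=1.0)
# (top, bottom, left, right) from the prior center (15, 15), divided by
# the prior's height/width: (3, 11, 7, 9) / 10
print(loc)                       # tensor([[0.3000, 1.1000, 0.7000, 0.9000]])
print(tblr2bboxes(priors, loc))  # tensor([[ 8., 12., 24., 26.]])
```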
- Args: - bboxes (Tensor): shape (m, 4) - scale (float): the scale factor of bboxes - - Returns: - (Tensor): shape (m, 4) scaled bboxes - """ - w_half = (bboxes[:, 2] - bboxes[:, 0] + 1) * .5 - h_half = (bboxes[:, 3] - bboxes[:, 1] + 1) * .5 - x_c = (bboxes[:, 2] + bboxes[:, 0] + 1) * .5 - y_c = (bboxes[:, 3] + bboxes[:, 1] + 1) * .5 - - w_half *= scale - h_half *= scale - - boxes_exp = torch.zeros_like(bboxes) - boxes_exp[:, 0] = x_c - w_half - boxes_exp[:, 2] = x_c + w_half - 1 - boxes_exp[:, 1] = y_c - h_half - boxes_exp[:, 3] = y_c + h_half - 1 - return boxes_exp - - -def is_located_in(points, bboxes, is_aligned=False): - """ is center a locates in box b - Then we compute the area of intersect between box_a and box_b. - Args: - points: (tensor) bounding boxes, Shape: [m,2]. - bboxes: (tensor) bounding boxes, Shape: [n,4]. - If is_aligned is ``True``, then m mush be equal to n - Return: - (tensor) intersection area, Shape: [m, n]. If is_aligned ``True``, - then shape = [m] - """ - if not is_aligned: - return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \ - (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \ - (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \ - (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0)) - else: - return (points[:, 0] > bboxes[:, 0]) & \ - (points[:, 0] < bboxes[:, 2]) & \ - (points[:, 1] > bboxes[:, 1]) & \ - (points[:, 1] < bboxes[:, 3]) - - -def bboxes_area(bboxes): - """Compute the area of an array of boxes.""" - w = (bboxes[:, 2] - bboxes[:, 0] + 1) - h = (bboxes[:, 3] - bboxes[:, 1] + 1) - areas = w * h - - return areas From 827f16471cd0486b9a9e5594655f9b3396d6d361 Mon Sep 17 00:00:00 2001 From: wangxinjiang Date: Thu, 23 Apr 2020 16:44:43 +0800 Subject: [PATCH 06/11] added fsaf_head and config --- configs/fsaf/fsaf_r50_fpn_1x_coco.py | 55 ++-- .../models/anchor_heads/anchor_tblr_target.py | 192 ----------- mmdet/models/anchor_heads/fsaf_head.py | 306 ++++-------------- 3 files changed, 96 insertions(+), 457 deletions(-) delete mode 100644 mmdet/models/anchor_heads/anchor_tblr_target.py diff --git a/configs/fsaf/fsaf_r50_fpn_1x_coco.py b/configs/fsaf/fsaf_r50_fpn_1x_coco.py index 637967f9ea9..d1322c7d083 100644 --- a/configs/fsaf/fsaf_r50_fpn_1x_coco.py +++ b/configs/fsaf/fsaf_r50_fpn_1x_coco.py @@ -1,36 +1,25 @@ -_base_ = [ - '../_base_/datasets/coco_detection.py', - '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' -] +_base_ = '../retinanet/retinanet_r50_fpn_1x_coco.py' # model settings model = dict( type='FSAF', - pretrained='torchvision://resnet50', - backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs=True, - num_outs=5), bbox_head=dict( type='FSAFHead', - num_classes=81, + num_classes=80, in_channels=256, stacked_convs=4, feat_channels=256, - octave_base_scale=1, - scales_per_octave=1, - anchor_ratios=[1.0], - anchor_strides=[8, 16, 32, 64, 128], - target_normalizer=1.0, + reg_decoded_bbox=True, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=1, + scales_per_octave=1, + ratios=[1.0], + strides=[8, 16, 32, 64, 128], + center_offset=0.5), + bbox_coder=dict( + _delete_=True, + type='TBLRBBoxCoder', + normalizer=1.0), loss_cls=dict( type='FocalLoss', use_sigmoid=True, @@ -38,13 +27,17 @@ alpha=0.25, loss_weight=1.0, reduction='none'), - loss_bbox=dict(type='IoULossTBLR', - eps=1e-6, - 
loss_weight=1.0, - reduction='none'))) + loss_bbox=dict( + _delete_=True, + type='IoULoss', + eps=1e-6, + loss_weight=1.0, + reduction='none'))) + # training and testing settings train_cfg = dict( assigner=dict( + _delete_=True, type='EffectiveAreaAssigner', pos_area_thr=0.2, neg_area_thr=0.2, @@ -53,4 +46,6 @@ pos_weight=-1, debug=False) optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) -optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) \ No newline at end of file +optimizer_config = dict(_delete_=True, + grad_clip=dict(max_norm=10, norm_type=2)) +total_epochs = 12 diff --git a/mmdet/models/anchor_heads/anchor_tblr_target.py b/mmdet/models/anchor_heads/anchor_tblr_target.py deleted file mode 100644 index f12acd8fa65..00000000000 --- a/mmdet/models/anchor_heads/anchor_tblr_target.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch - -from mmdet.core.bbox import (PseudoSampler, assign_and_sample, bboxes2tblr, - build_assigner) -from mmdet.core.utils import multi_apply - - -def anchor_tblr_target(anchor_list, - valid_flag_list, - gt_bboxes_list, - img_metas, - target_normalizer, - cfg, - gt_bboxes_ignore_list=None, - gt_labels_list=None, - sampling=True, - unmap_outputs=True): - """Compute regression and classification targets for anchors. - - Args: - anchor_list (list[list]): Multi level anchors of each image. - valid_flag_list (list[list]): Multi level valid flags of each image. - gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. - img_metas (list[dict]): Meta info of each image. - target_means (Iterable): Mean value of regression targets. - target_normalizer (float): Std value of regression targets. - cfg (dict): RPN train configs. - - Returns: - tuple - """ - num_imgs = len(img_metas) - assert len(anchor_list) == len(valid_flag_list) == num_imgs - - # anchor number of multi levels - num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] - # concat all level anchors and flags to a single tensor - for i in range(num_imgs): - assert len(anchor_list[i]) == len(valid_flag_list[i]) - anchor_list[i] = torch.cat(anchor_list[i]) - valid_flag_list[i] = torch.cat(valid_flag_list[i]) - - # compute targets for each image - if gt_bboxes_ignore_list is None: - gt_bboxes_ignore_list = [None for _ in range(num_imgs)] - if gt_labels_list is None: - gt_labels_list = [None for _ in range(num_imgs)] - (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, - pos_inds_list, neg_inds_list, pos_assigned_gt_inds) = multi_apply( - anchor_target_single, - anchor_list, - valid_flag_list, - gt_bboxes_list, - gt_bboxes_ignore_list, - gt_labels_list, - img_metas, - target_normalizer=target_normalizer, - cfg=cfg, - sampling=sampling, - unmap_outputs=unmap_outputs) - # no valid anchors - if any([labels is None for labels in all_labels]): - return None - # sampled anchors of all images - num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) - num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) - # split targets to a list w.r.t. 
multiple levels - labels_list = images_to_levels(all_labels, num_level_anchors) - label_weights_list = images_to_levels(all_label_weights, num_level_anchors) - bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors) - bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors) - pos_assigned_gt_inds_list = images_to_levels(pos_assigned_gt_inds, - num_level_anchors) - return (labels_list, label_weights_list, bbox_targets_list, - bbox_weights_list, num_total_pos, num_total_neg, - pos_assigned_gt_inds_list) - - -def images_to_levels(target, num_level_anchors): - """Convert targets by image to targets by feature level. - - [target_img0, target_img1] -> [target_level0, target_level1, ...] - """ - target = torch.stack(target, 0) - level_targets = [] - start = 0 - for n in num_level_anchors: - end = start + n - level_targets.append(target[:, start:end].squeeze(0)) - start = end - return level_targets - - -def anchor_target_single(flat_anchors, - valid_flags, - gt_bboxes, - gt_bboxes_ignore, - gt_labels, - img_meta, - target_normalizer, - cfg, - sampling=True, - unmap_outputs=True): - inside_flags = anchor_inside_flags(flat_anchors, valid_flags, - img_meta['img_shape'][:2], - cfg.allowed_border) - if not inside_flags.any(): - return (None, ) * 6 - # assign gt and sample anchors - anchors = flat_anchors[inside_flags, :] - - if sampling: - assign_result, sampling_result = assign_and_sample( - anchors, gt_bboxes, gt_bboxes_ignore, None, cfg) - else: - bbox_assigner = build_assigner(cfg.assigner) - assign_result = bbox_assigner.assign(anchors, gt_bboxes, - gt_bboxes_ignore, gt_labels) - bbox_sampler = PseudoSampler() - sampling_result = bbox_sampler.sample(assign_result, anchors, - gt_bboxes) - - num_valid_anchors = anchors.shape[0] - bbox_targets = torch.zeros_like(anchors) - bbox_weights = torch.zeros_like(anchors) - labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long) - label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) - pos_assigned_gt_inds = anchors.new_full((num_valid_anchors, ), - -1, - dtype=torch.long) - - pos_inds = sampling_result.pos_inds - neg_inds = sampling_result.neg_inds - if len(pos_inds) > 0: - pos_bbox_targets = bboxes2tblr(sampling_result.pos_bboxes, - sampling_result.pos_gt_bboxes, - target_normalizer) - bbox_targets[pos_inds, :] = pos_bbox_targets - bbox_weights[pos_inds, :] = 1.0 - pos_assigned_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds - if gt_labels is None: - labels[pos_inds] = 1 - else: - labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] - if cfg.pos_weight <= 0: - label_weights[pos_inds] = 1.0 - else: - label_weights[pos_inds] = cfg.pos_weight - if len(neg_inds) > 0: - label_weights[neg_inds] = 1.0 - - # map up to original set of anchors - if unmap_outputs: - num_total_anchors = flat_anchors.size(0) - labels = unmap(labels, num_total_anchors, inside_flags) - label_weights = unmap(label_weights, num_total_anchors, inside_flags) - bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) - bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) - pos_assigned_gt_inds = unmap( - pos_assigned_gt_inds, num_total_anchors, inside_flags, fill=-1) - - return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, - neg_inds, pos_assigned_gt_inds) - - -def anchor_inside_flags(flat_anchors, - valid_flags, - img_shape, - allowed_border=0): - img_h, img_w = img_shape[:2] - if allowed_border >= 0: - inside_flags = valid_flags & \ - (flat_anchors[:, 0] >= 
-allowed_border).type(torch.uint8) & \ - (flat_anchors[:, 1] >= -allowed_border).type(torch.uint8) & \ - (flat_anchors[:, 2] < img_w + allowed_border).type(torch.uint8) & \ - (flat_anchors[:, 3] < img_h + allowed_border).type(torch.uint8) - else: - inside_flags = valid_flags - return inside_flags - - -def unmap(data, count, inds, fill=0): - """ Unmap a subset of item (data) back to the original set of items (of - size count) """ - if data.dim() == 1: - ret = data.new_full((count, ), fill) - ret[inds] = data - else: - new_size = (count, ) + data.size()[1:] - ret = data.new_full(new_size, fill) - ret[inds, :] = data - return ret diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py index 2271499e2c5..deb83694642 100644 --- a/mmdet/models/anchor_heads/fsaf_head.py +++ b/mmdet/models/anchor_heads/fsaf_head.py @@ -1,83 +1,12 @@ import numpy as np import torch -from mmdet.core import (anchor_inside_flags, build_anchor_generator, - build_assigner, build_bbox_coder, build_sampler, - force_fp32, images_to_levels, multi_apply, - multiclass_nms, unmap) - -from ..losses import IoULoss -from ..losses.utils import weight_reduce_loss, weighted_loss -from ..registry import HEADS, LOSSES -from .retina_head import RetinaHead - - -@weighted_loss -def iou_loss_tblr(pred, target, eps=1e-6): - """Calculate the iou loss. - - Get iou loss when both the prediction and targets are - encoded in TBLR format. - - Args: - pred: shape (num_anchor, 4) - target: shape (num_anchor, 4) - eps: the minimum iou threshold +from mmcv.cnn import normal_init - Returns: - loss: shape (num_anchor), IoU loss - """ - xt, xb, xl, xr = torch.split(pred, 1, dim=-1) - - # the ground truth position - gt, gb, gl, gr = torch.split(target, 1, dim=-1) - - # compute the bounding box size - X = (xt + xb) * (xl + xr) # AreaX - G = (gt + gb) * (gl + gr) # AreaG - - # compute the IOU - Ih = torch.min(xt, gt) + torch.min(xb, gb) - Iw = torch.min(xl, gl) + torch.min(xr, gr) - - Inter = Ih * Iw - Union = (X + G - Inter).clamp(min=1) - # minimum area should be 1 - - IoU = Inter / Union - IoU = IoU.squeeze() - ious = IoU.clamp(min=eps) - loss = -ious.log() - return loss - - -@LOSSES.register_module -class IoULossTBLR(IoULoss): - - def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0): - super(IoULossTBLR, self).__init__(eps, reduction, loss_weight) - - def forward(self, - pred, - target, - weight=None, - avg_factor=None, - reduction_override=None, - **kwargs): - if weight is not None and not torch.any(weight > 0): - return (pred * weight).sum() # 0 - assert reduction_override in (None, 'none', 'mean', 'sum') - reduction = ( - reduction_override if reduction_override else self.reduction) - weight = weight.sum(dim=-1) / 4. # iou loss is a scalar! 
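For reference, the identity that makes the deletion below safe: an IoU computed directly on positive (t, b, l, r) offsets measured from a shared anchor point equals the IoU of the decoded boxes, which is what the stock `IoULoss` sees once the head decodes predictions first (`reg_decoded_bbox=True` in the new config). A standalone sketch mirroring the removed math:
```python
import torch

def iou_from_tblr(pred, target, eps=1e-6):
    # pred/target: (N, 4) positive (top, bottom, left, right) offsets
    # measured from the same per-anchor reference point
    xt, xb, xl, xr = pred.split(1, dim=-1)
    gt, gb, gl, gr = target.split(1, dim=-1)
    area_p = (xt + xb) * (xl + xr)
    area_g = (gt + gb) * (gl + gr)
    ih = torch.min(xt, gt) + torch.min(xb, gb)
    iw = torch.min(xl, gl) + torch.min(xr, gr)
    inter = ih * iw
    union = (area_p + area_g - inter).clamp(min=1)  # minimum area is 1
    return (inter / union).clamp(min=eps).squeeze(-1)

pred = torch.tensor([[3., 11., 7., 9.]])
print(iou_from_tblr(pred, pred))  # tensor([1.]) -> loss -log(1) = 0
```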
- loss = self.loss_weight * iou_loss_tblr( - pred, - target, - weight, - eps=self.eps, - reduction=reduction, - avg_factor=avg_factor, - **kwargs) - return loss +from mmdet.core import (anchor_inside_flags, force_fp32, images_to_levels, + multi_apply, unmap) +from .retina_head import RetinaHead +from ..losses.utils import weight_reduce_loss +from ..registry import HEADS @HEADS.register_module @@ -125,6 +54,12 @@ def forward_single(self, x): return cls_score, self.relu(bbox_pred) # TBLR encoder only accepts positive bbox_pred + def init_weights(self): + super(FSAFHead, self).init_weights() + normal_init(self.retina_reg, std=0.01, bias=0.25) + # the positive bias in self.retina_reg conv is to prevent predicted + # bbox with 0 area + def _get_targets_single(self, flat_anchors, valid_flags, @@ -134,6 +69,15 @@ def _get_targets_single(self, img_meta, label_channels=1, unmap_outputs=True): + """Compute regression and classification targets for anchors in + a single image. + + Most of the code is the same as in the base class + :obj:`AnchorHead`, except that it also collects and returns + the matched gt index in the image (from 0 to num_gt-1). If the + pixel is not matched to any gt, the corresponding value in + pos_gt_inds is -1. + """ inside_flags = anchor_inside_flags(flat_anchors, valid_flags, img_meta['img_shape'][:2], self.train_cfg.allowed_border) @@ -156,8 +100,8 @@ def _get_targets_single(self, dtype=torch.long) label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) pos_gt_inds = anchors.new_full((num_valid_anchors,), - -1, - dtype=torch.long) + -1, + dtype=torch.long) pos_inds = sampling_result.pos_inds neg_inds = sampling_result.neg_inds @@ -194,86 +138,19 @@ def _get_targets_single(self, pos_gt_inds = unmap( pos_gt_inds, num_total_anchors, inside_flags, fill=-1) - return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds, pos_gt_inds) - def get_targets(self, - anchor_list, - valid_flag_list, - gt_bboxes_list, - img_metas, - gt_bboxes_ignore_list=None, - gt_labels_list=None, - label_channels=1, - unmap_outputs=True): - """Compute regression and classification targets for anchors in - multiple images.
- - """ - num_imgs = len(img_metas) - assert len(anchor_list) == len(valid_flag_list) == num_imgs - - # anchor number of multi levels - num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] - # concat all level anchors to a single tensor - concat_anchor_list = [] - concat_valid_flag_list = [] - for i in range(num_imgs): - assert len(anchor_list[i]) == len(valid_flag_list[i]) - concat_anchor_list.append(torch.cat(anchor_list[i])) - concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) - - # compute targets for each image - if gt_bboxes_ignore_list is None: - gt_bboxes_ignore_list = [None for _ in range(num_imgs)] - if gt_labels_list is None: - gt_labels_list = [None for _ in range(num_imgs)] - (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, - pos_inds_list, neg_inds_list, pos_assigned_gt_inds) = multi_apply( - self._get_targets_single, - concat_anchor_list, - concat_valid_flag_list, - gt_bboxes_list, - gt_bboxes_ignore_list, - gt_labels_list, - img_metas, - label_channels=label_channels, - unmap_outputs=unmap_outputs) - # no valid anchors - if any([labels is None for labels in all_labels]): - return None - # sampled anchors of all images - num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) - num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) - # split targets to a list w.r.t. multiple levels - labels_list = images_to_levels(all_labels, num_level_anchors) - label_weights_list = images_to_levels(all_label_weights, - num_level_anchors) - bbox_targets_list = images_to_levels(all_bbox_targets, - num_level_anchors) - bbox_weights_list = images_to_levels(all_bbox_weights, - num_level_anchors) - pos_assigned_gt_inds_list = images_to_levels(pos_assigned_gt_inds, - num_level_anchors) - return (labels_list, label_weights_list, bbox_targets_list, - bbox_weights_list, num_total_pos, num_total_neg, - pos_assigned_gt_inds_list) - - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def loss( - self, - cls_scores, - bbox_preds, - gt_bboxes, - gt_labels, - img_metas, - cfg, - gt_bboxes_ignore=None, - ): + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == len(self.anchor_generators) + assert len(featmap_sizes) == len(self.anchor_generator.base_anchors) batch_size = len(gt_bboxes) device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( @@ -284,11 +161,8 @@ def loss( valid_flag_list, gt_bboxes, img_metas, - self.target_normalizer, - cfg, gt_bboxes_ignore_list=gt_bboxes_ignore, - gt_labels_list=gt_labels, - sampling=self.sampling) + gt_labels_list=gt_labels) if cls_reg_targets is None: return None (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, @@ -298,16 +172,25 @@ def loss( num_gts = np.array(list(map(len, gt_labels))) num_total_samples = ( num_total_pos + num_total_neg if self.sampling else num_total_pos) + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + concat_anchor_list = [] + for i in range(len(anchor_list)): + concat_anchor_list.append(torch.cat(anchor_list[i])) + all_anchor_list = images_to_levels(concat_anchor_list, + num_level_anchors) + losses_cls, losses_bbox = multi_apply( self.loss_single, cls_scores, bbox_preds, + all_anchor_list, labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, - 
num_total_samples=num_total_samples, - cfg=cfg) + num_total_samples=num_total_samples) cum_num_gts = list(np.cumsum(num_gts)) for i, assign in enumerate(pos_assigned_gt_inds_list): for j in range(1, batch_size): @@ -359,7 +242,7 @@ def calculate_accuracy(self, cls_scores, labels_list, pos_inds): def argmax(x): return x.argmax(1) if x.numel() > 0 else -100 - num_correct = sum([(argmax(score) + 1 == label).sum() + num_correct = sum([(argmax(score) == label).sum() for score, label in zip(scores, labels)]) return num_correct.float() / (num_pos + 1e-3) @@ -371,7 +254,7 @@ def collect_loss_level_single(self, cls_loss, reg_loss, cls_loss (tensor): classification loss of each feature map pixel, shape (num_anchor, num_class) reg_loss (tensor): regression loss of each feature map pixel, - shape (num_anchor) + shape (num_anchor, 4) pos_assigned_gt_inds (tensor): shape (num_anchor), indicating which gt the prior is assigned to (-1: no assignment) labels_seq: The rank of labels @@ -379,6 +262,9 @@ def collect_loss_level_single(self, cls_loss, reg_loss, Returns: """ + if len(reg_loss.shape) == 2: # iou loss has shape [num_prior, 4] + reg_loss = reg_loss.sum(dim=-1) + loss = cls_loss.sum(dim=-1) + reg_loss # total loss at each feature map point match = ( @@ -392,7 +278,7 @@ def collect_loss_level_single(self, cls_loss, reg_loss, ]) return losses_, - def reassign_loss_single(self, cls_loss, reg_loss, pos_assigned_gt_inds, + def reassign_loss_single(self, cls_loss, reg_loss, assigned_gt_inds, labels, level, min_levels): """Reassign loss values at each level. @@ -403,7 +289,7 @@ def reassign_loss_single(self, cls_loss, reg_loss, pos_assigned_gt_inds, cls_loss (tensor): shape (num_anchors, num_classes) classification loss reg_loss (tensor): shape (num_anchors) regression loss - pos_assigned_gt_inds (tensor): shape (num_anchors), + assigned_gt_inds (tensor): shape (num_anchors), the gt indices that each positive anchor corresponds to. (-1 if it is a negative one) labels (tensor): shape (num_anchors). Label assigned to each pixel @@ -418,78 +304,28 @@ def reassign_loss_single(self, cls_loss, reg_loss, pos_assigned_gt_inds, reg_loss: shape (num_anchors). Corrected regression loss keep_indices: shape (num_anchors). 
Indicating final positive anchors """ - - unmatch_gt_inds = torch.nonzero(min_levels != level) - # gts indices that unmatch with the current level - match_gt_inds = torch.nonzero(min_levels == level) - loc_weight = cls_loss.new_ones(cls_loss.size(0)) - cls_weight = cls_loss.new_ones(cls_loss.size(0), cls_loss.size(1)) - zeroing_indices = (pos_assigned_gt_inds.view( - -1, 1) == unmatch_gt_inds.view(1, -1)).any(dim=-1) - keep_indices = (pos_assigned_gt_inds.view(-1, 1) == match_gt_inds.view( - 1, -1)).any(dim=-1) - loc_weight[zeroing_indices] = 0 - - # Only the weight corresponding to the label is - # zeroed out if not selected - zeroing_labels = labels[zeroing_indices] - 1 - # the original labels assigned to the anchor box - assert (zeroing_labels >= 0).all() - cls_weight[zeroing_indices, zeroing_labels] = 0 + loc_weight = torch.ones_like(reg_loss) + cls_weight = torch.ones_like(cls_loss) + pos_flags = assigned_gt_inds >= 0 # positive pixel flag + pos_indices = torch.nonzero(pos_flags).flatten() + + if pos_flags.any(): # pos pixels exist + pos_assigned_gt_inds = assigned_gt_inds[pos_flags] + zeroing_indices = (min_levels[pos_assigned_gt_inds] != level) + neg_indices = pos_indices[zeroing_indices] + + if neg_indices.numel(): + pos_flags[neg_indices] = 0 + loc_weight[neg_indices] = 0 + # Only the weight corresponding to the label is + # zeroed out if not selected + zeroing_labels = labels[neg_indices] + # the original labels assigned to the anchor box + assert (zeroing_labels >= 0).all() + cls_weight[neg_indices, zeroing_labels] = 0 # weighted loss for both cls and reg loss cls_loss = weight_reduce_loss(cls_loss, cls_weight, reduction='sum') reg_loss = weight_reduce_loss(reg_loss, loc_weight, reduction='sum') - return cls_loss, reg_loss, keep_indices - - def get_bboxes_single(self, - cls_score_list, - bbox_pred_list, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=False): - """ - Transform outputs for a single batch item into labeled boxes. - """ - assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) - mlvl_bboxes = [] - mlvl_scores = [] - for cls_score, bbox_pred, anchors in zip(cls_score_list, - bbox_pred_list, mlvl_anchors): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - cls_score = cls_score.permute(1, 2, - 0).reshape(-1, self.cls_out_channels) - if self.use_sigmoid_cls: - scores = cls_score.sigmoid() - else: - scores = cls_score.softmax(-1) - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: - # Get maximum scores for foreground classes.
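`reassign_loss_single` consumes `min_levels`, the per-gt argmin over pyramid levels of the summed cls+reg loss collected by `collect_loss_level_single`; this is FSAF's online feature selection. The selection step in isolation (shapes and numbers invented for illustration):
```python
import torch

# loss_levels[l, j]: summed cls+reg loss of gt j collected on level l
loss_levels = torch.tensor([[4., 2., 9.],
                            [1., 5., 7.],
                            [3., 8., 6.],
                            [5., 4., 2.],
                            [6., 3., 8.]])
loss, min_levels = loss_levels.min(dim=0)
print(min_levels)  # tensor([1, 0, 3]): the level whose anchors keep each gt;
                   # every other level zeroes that gt's cls/reg weights
```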
From c8e6244d23d4003f5554bafe65380d2edacef907 Mon Sep 17 00:00:00 2001
From: wangxinjiang
Date: Thu, 23 Apr 2020 19:19:06 +0800
Subject: [PATCH 07/11] added more fsaf configs

---
 configs/fsaf/fsaf_r101_fpn_1x.py            | 131 --------------------
 configs/fsaf/fsaf_r101_fpn_1x_coco.py       |   2 +
 configs/fsaf/fsaf_r50_fpn_1x_coco.py        |   5 +-
 configs/fsaf/fsaf_x101_64x4d_fpn_1x_coco.py |  13 ++
 mmdet/core/bbox/__init__.py                 |  13 +-
 mmdet/core/bbox/coder/tblr_bbox_coder.py    |  10 +-
 mmdet/core/bbox/transforms.py               |  65 ----------
 mmdet/models/anchor_heads/fsaf_head.py      |  45 +++----
 8 files changed, 53 insertions(+), 231 deletions(-)
 delete mode 100644 configs/fsaf/fsaf_r101_fpn_1x.py
 create mode 100644 configs/fsaf/fsaf_r101_fpn_1x_coco.py
 create mode 100644 configs/fsaf/fsaf_x101_64x4d_fpn_1x_coco.py

diff --git a/configs/fsaf/fsaf_r101_fpn_1x.py b/configs/fsaf/fsaf_r101_fpn_1x.py
deleted file mode 100644
index 720127a5e65..00000000000
--- a/configs/fsaf/fsaf_r101_fpn_1x.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# model settings
-model = dict(
-    type='RetinaNet',
-    pretrained='torchvision://resnet101',
-    backbone=dict(
-        type='ResNet',
-        depth=101,
-        num_stages=4,
-        out_indices=(0, 1, 2, 3),
-        frozen_stages=1,
-        style='pytorch'),
-    neck=dict(
-        type='FPN',
-        in_channels=[256, 512, 1024, 2048],
-        out_channels=256,
-        start_level=1,
-        add_extra_convs=True,
-        num_outs=5),
-    bbox_head=dict(
-        type='FSAFHead',
-        num_classes=81,
-        in_channels=256,
-        stacked_convs=4,
-        feat_channels=256,
-        octave_base_scale=1,
-        scales_per_octave=1,
-        anchor_ratios=[1.0],
-        anchor_strides=[8, 16, 32, 64, 128],
-        target_normalizer=1.0,
-        loss_cls=dict(
-            type='FocalLoss',
-            use_sigmoid=True,
-            gamma=2.0,
-            alpha=0.25,
-            loss_weight=1.0,
-            reduction='none'),
-        loss_bbox=dict(type='IoULossTBLR',
-                       eps=1e-6,
-                       loss_weight=1.0,
-                       reduction='none')))
-# training and testing settings
-train_cfg = dict(
-    assigner=dict(
-        type='EffectiveAreaAssigner',
-        pos_area_thr=0.2,
-        neg_area_thr=0.2,
-        min_pos_iof=0.01),
-    allowed_border=-1,
-    pos_weight=-1,
-    debug=False)
-test_cfg = dict(
-    nms_pre=1000,
-    min_bbox_size=0,
-    score_thr=0.05,
-    nms=dict(type='nms', iou_thr=0.5),
-    max_per_img=100)
-# dataset settings
-dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
-train_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(type='LoadAnnotations', with_bbox=True),
-    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
-    dict(type='RandomFlip', flip_ratio=0.5),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
-    dict(type='DefaultFormatBundle'),
-    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
-]
-test_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(
-        type='MultiScaleFlipAug',
-        img_scale=(1333, 800),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=True),
-            dict(type='RandomFlip'),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='Pad', size_divisor=32),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
-]
-data = dict(
-    imgs_per_gpu=2,
-    workers_per_gpu=2,
-    train=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_train2017.json',
-        img_prefix=data_root + 'train2017/',
-        pipeline=train_pipeline),
-    val=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline),
-    test=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline))
-
-optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
-optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
-# learning policy
-lr_config = dict(
-    policy='step',
-    warmup='linear',
-    warmup_iters=500,
-    warmup_ratio=1.0 / 3,
-    step=[8, 11])
-checkpoint_config = dict(interval=1)
-# yapf:disable
-log_config = dict(
-    interval=50,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        # dict(type='TensorboardLoggerHook')
-    ])
-# yapf:enable
-# runtime settings
-total_epochs = 12
-dist_params = dict(backend='nccl')
-log_level = 'INFO'
-work_dir = './work_dirs/fsaf_r101_fpn_1x'
-load_from = None
-resume_from = None
-workflow = [('train', 1)]
diff --git a/configs/fsaf/fsaf_r101_fpn_1x_coco.py b/configs/fsaf/fsaf_r101_fpn_1x_coco.py
new file mode 100644
index 00000000000..95a7ae2de59
--- /dev/null
+++ b/configs/fsaf/fsaf_r101_fpn_1x_coco.py
@@ -0,0 +1,2 @@
+_base_ = './fsaf_r50_fpn_1x_coco.py'
+model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))
diff --git a/configs/fsaf/fsaf_r50_fpn_1x_coco.py b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
index d1322c7d083..fbc75a45b34 100644
--- a/configs/fsaf/fsaf_r50_fpn_1x_coco.py
+++ b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
@@ -32,7 +32,8 @@
             type='IoULoss',
             eps=1e-6,
             loss_weight=1.0,
-            reduction='none')))
+            reduction='none'),
+    ))

 # training and testing settings
 train_cfg = dict(
@@ -48,4 +49,4 @@
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(_delete_=True,
                         grad_clip=dict(max_norm=10, norm_type=2))
-total_epochs = 12
+total_epochs = 13
diff --git a/configs/fsaf/fsaf_x101_64x4d_fpn_1x_coco.py b/configs/fsaf/fsaf_x101_64x4d_fpn_1x_coco.py
new file mode 100644
index 00000000000..b966f24969a
--- /dev/null
+++ b/configs/fsaf/fsaf_x101_64x4d_fpn_1x_coco.py
@@ -0,0 +1,13 @@
+_base_ = './fsaf_r50_fpn_1x_coco.py'
+model = dict(
+    pretrained='open-mmlab://resnext101_64x4d',
+    backbone=dict(
+        type='ResNeXt',
+        depth=101,
+        groups=64,
+        base_width=4,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        style='pytorch'))
diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py
index ccb506db8e2..c9bbc48aa75 100644
--- a/mmdet/core/bbox/__init__.py
+++ b/mmdet/core/bbox/__init__.py
@@ -1,6 +1,5 @@
-from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner
-from .builder import (  # isort:skip, avoid recursive imports
-    build_assigner, build_sampler, build_bbox_coder)
+from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner, \
+    EffectiveAreaAssigner
 from .coder import BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder, \
     TBLRBBoxCoder
 from .iou_calculators import BboxOverlaps2D, bbox_overlaps
@@ -8,8 +7,10 @@
                        InstanceBalancedPosSampler, IoUBalancedNegSampler,
                        PseudoSampler, RandomSampler, SamplingResult)
 from .transforms import (bbox2result, bbox2roi, bbox_flip, bbox_mapping,
-                         bbox_mapping_back, distance2bbox, roi2bbox,
-                         tblr2bboxes, bboxes2tblr)
+                         bbox_mapping_back, distance2bbox, roi2bbox)
+
+from .builder import (  # isort:skip, avoid recursive imports
+    build_assigner, build_sampler, build_bbox_coder)

 __all__ = [
     'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
@@ -18,5 +19,5 @@
     'SamplingResult', 'build_assigner', 'build_sampler', 'bbox_flip',
     'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
     'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
-    'DeltaXYWHBBoxCoder', 'tblr2bboxes', 'bboxes2tblr', 'TBLRBBoxCoder'
+    'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'EffectiveAreaAssigner'
 ]
\ No newline at end of file
diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py
index 700992bb647..74623f97b7d 100644
--- a/mmdet/core/bbox/coder/tblr_bbox_coder.py
+++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -96,7 +96,7 @@ def tblr2bboxes(priors, tblr, normalizer=1.0, max_shape=None):
     loc_decode = tblr * normalizer
     prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
-    wh = priors[:, 2:4] - priors[:, 0:2] + 1
+    wh = priors[:, 2:4] - priors[:, 0:2]
     w, h = torch.split(wh, 1, dim=1)
     loc_decode[:, :2] *= h
     loc_decode[:, 2:] *= w
@@ -107,8 +107,8 @@ def tblr2bboxes(priors, tblr, normalizer=1.0, max_shape=None):
     ymax = prior_centers[:, 1].unsqueeze(1) + bottom
     boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1)
     if max_shape is not None:
-        boxes[:, 0].clamp_(min=0, max=max_shape[1] - 1)
-        boxes[:, 1].clamp_(min=0, max=max_shape[0] - 1)
-        boxes[:, 2].clamp_(min=0, max=max_shape[1] - 1)
-        boxes[:, 3].clamp_(min=0, max=max_shape[0] - 1)
+        boxes[:, 0].clamp_(min=0, max=max_shape[1])
+        boxes[:, 1].clamp_(min=0, max=max_shape[0])
+        boxes[:, 2].clamp_(min=0, max=max_shape[1])
+        boxes[:, 3].clamp_(min=0, max=max_shape[0])
     return boxes
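For reference, the TBLR convention patched above encodes a gt box as the (top, bottom, left, right) distances from the prior's center, with top/bottom normalized by the prior's height and left/right by its width. A minimal round-trip sketch of that convention (simplified from the coder: scalar normalizer, no clamping to the image shape):

```python
import torch

def encode_tblr(priors, gts, normalizer=1.0):
    # distances from each prior center to the gt's four sides,
    # normalized by the prior's height (t, b) and width (l, r)
    centers = (priors[:, :2] + priors[:, 2:]) / 2
    w, h = priors[:, 2] - priors[:, 0], priors[:, 3] - priors[:, 1]
    top = (centers[:, 1] - gts[:, 1]) / h
    bottom = (gts[:, 3] - centers[:, 1]) / h
    left = (centers[:, 0] - gts[:, 0]) / w
    right = (gts[:, 2] - centers[:, 0]) / w
    return torch.stack((top, bottom, left, right), dim=1) / normalizer

def decode_tblr(priors, tblr, normalizer=1.0):
    centers = (priors[:, :2] + priors[:, 2:]) / 2
    w, h = priors[:, 2] - priors[:, 0], priors[:, 3] - priors[:, 1]
    t, b, l, r = (tblr * normalizer).unbind(dim=1)
    return torch.stack((centers[:, 0] - l * w, centers[:, 1] - t * h,
                        centers[:, 0] + r * w, centers[:, 1] + b * h), dim=1)

priors = torch.tensor([[10., 10., 30., 40.]])  # one 20x30 prior
gts = torch.tensor([[5., 12., 35., 38.]])
assert torch.allclose(decode_tblr(priors, encode_tblr(priors, gts)), gts)
```

Dropping the `+ 1` terms in the coder, as this patch does, makes the encode/decode pair an exact inverse under the continuous-coordinate convention.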
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
index b63554771b7..9ba80f45754 100644
--- a/mmdet/core/bbox/transforms.py
+++ b/mmdet/core/bbox/transforms.py
@@ -111,68 +111,3 @@ def distance2bbox(points, distance, max_shape=None):
         x2 = x2.clamp(min=0, max=max_shape[1])
         y2 = y2.clamp(min=0, max=max_shape[0])
     return torch.stack([x1, y1, x2, y2], -1)
-
-
-def bboxes2tblr(priors, gt, normalizer=1.0):
-    """Encode ground truth boxes
-
-    Args:
-        priors (FloatTensor): Prior boxes in point form
-            Shape: [num_proposals,4].
-        gt (FloatTensor): Coords of ground truth for each prior in point-form
-            Shape: [num_proposals, 4].
-        normalizer (float): normalization parameter of encoded boxes
-
-    Return:
-        encoded boxes (FloatTensor), Shape: [num_proposals, 4]
-    """
-
-    # dist b/t match center and prior's center
-    prior_centers = (priors[:, 0:2] + priors[:, 2:4] + 1) / 2
-    wh = priors[:, 2:4] - priors[:, 0:2] + 1
-
-    xmin, ymin, xmax, ymax = gt.split(1, dim=1)
-    top = prior_centers[:, 1].unsqueeze(1) - ymin
-    bottom = ymax - prior_centers[:, 1].unsqueeze(1) + 1
-    left = prior_centers[:, 0].unsqueeze(1) - xmin
-    right = xmax - prior_centers[:, 0].unsqueeze(1) + 1
-    loc = torch.cat((top, bottom, left, right), dim=1)
-    w, h = torch.split(wh, 1, dim=1)
-    loc[:, :2] /= h
-    # convert them to the coordinate on the featuremap: 0 -fm_size
-    loc[:, 2:] /= w
-    return loc / normalizer
-
-
-def tblr2bboxes(priors, tblr, normalizer=1.0, max_shape=None):
-    """Decode tblr outputs to prediction boxes
-
-    Args:
-        priors (FloatTensor): Prior boxes in point form
-            Shape: [n,4].
-        tblr (FloatTensor): Coords of network output in tblr form
-            Shape: [n, 4].
-        normalizer (float): normalization parameter of encoded boxes
-        max_shape (tuple): Shape of the image.
-
-    Return:
-        encoded boxes (FloatTensor), Shape: [n, 4]
-    """
-    loc_decode = tblr * normalizer
-    prior_centers = (priors[:, 0:2] + priors[:, 2:4] + 1) / 2
-    wh = priors[:, 2:4] - priors[:, 0:2] + 1
-    w, h = torch.split(wh, 1, dim=1)
-    loc_decode[:, :2] *= h
-    loc_decode[:, 2:] *= w
-    top, bottom, left, right = loc_decode.split(1, dim=1)
-    xmin = prior_centers[:, 0].unsqueeze(1) - left
-    xmax = prior_centers[:, 0].unsqueeze(1) + right - 1
-    ymin = prior_centers[:, 1].unsqueeze(1) - top
-    ymax = prior_centers[:, 1].unsqueeze(1) + bottom - 1
-    boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1)
-    if max_shape is not None:
-        boxes[:, 0].clamp_(min=0, max=max_shape[1] - 1)
-        boxes[:, 1].clamp_(min=0, max=max_shape[0] - 1)
-        boxes[:, 2].clamp_(min=0, max=max_shape[1] - 1)
-        boxes[:, 3].clamp_(min=0, max=max_shape[0] - 1)
-    return boxes
diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py
index deb83694642..4a3008ccb36 100644
--- a/mmdet/models/anchor_heads/fsaf_head.py
+++ b/mmdet/models/anchor_heads/fsaf_head.py
@@ -4,9 +4,10 @@
 from mmdet.core import (anchor_inside_flags, force_fp32, images_to_levels,
                         multi_apply, unmap)
-from .retina_head import RetinaHead
+
 from ..losses.utils import weight_reduce_loss
 from ..registry import HEADS
+from .retina_head import RetinaHead


 @HEADS.register_module
@@ -112,7 +113,7 @@ def _get_targets_single(self,
         else:
             pos_bbox_targets = sampling_result.pos_gt_bboxes
         bbox_targets[pos_inds, :] = pos_bbox_targets
-        bbox_weights[pos_inds, :] = 1.0
+        bbox_weights[pos_inds, :] = 1. / 4.  # avg in tblr dims
         pos_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds
     if gt_labels is None:
         # only rpn gives gt_labels as None, this time FG is 1
@@ -216,7 +217,8 @@ def loss(self,
             list(range(len(losses_cls))),
             min_levels=argmin)
-        num_pos = torch.cat(pos_inds, 0).sum().float()
+        num_pos = torch.cat(pos_inds, 0).sum().float().clamp(min=1e-3)
+        # clamp to 1e-3 to prevent 0/0
         acc = self.calculate_accuracy(cls_scores, labels_list, pos_inds)
         for i in range(len(losses_cls)):
             losses_cls[i] /= num_pos
@@ -229,7 +231,7 @@ def loss(self,

     def calculate_accuracy(self, cls_scores, labels_list, pos_inds):
         with torch.no_grad():
-            num_pos = torch.cat(pos_inds, 0).sum().float()
+            num_pos = torch.cat(pos_inds, 0).sum().float().clamp(min=1e-3)
             num_class = cls_scores[0].size(1)
             scores = [
                 cls.permute(0, 2, 3, 1).reshape(-1, num_class)[pos]
@@ -244,10 +246,10 @@ def argmax(x):

             num_correct = sum([(argmax(score) == label).sum()
                                for score, label in zip(scores, labels)])
-            return num_correct.float() / (num_pos + 1e-3)
+            return num_correct.float() / num_pos

     def collect_loss_level_single(self, cls_loss, reg_loss,
-                                  pos_assigned_gt_inds, labels_seq):
+                                  assigned_gt_inds, labels_seq):
         """Get the average loss in each FPN level w.r.t. each gt label

         Args:
@@ -255,27 +257,26 @@ def collect_loss_level_single(self, cls_loss, reg_loss,
                 shape (num_anchor, num_class)
             reg_loss (tensor): regression loss of each feature map pixel,
                 shape (num_anchor, 4)
-            pos_assigned_gt_inds (tensor): shape (num_anchor), indicating
+            assigned_gt_inds (tensor): shape (num_anchor), indicating
                 which gt the prior is assigned to (-1: no assignment)
-            labels_seq: The rank of labels
+            labels_seq: The rank of labels. shape (num_gt)

         Returns:
-
+            shape: (num_gt), average loss of each gt in this level
         """
         if len(reg_loss.shape) == 2:  # iou loss has shape [num_prior, 4]
-            reg_loss = reg_loss.sum(dim=-1)
-
-        loss = cls_loss.sum(dim=-1) + reg_loss
+            reg_loss = reg_loss.sum(dim=-1)  # sum loss in tblr dims
         # total loss at each feature map point
-        match = (
-            pos_assigned_gt_inds.reshape(-1).unsqueeze(1) ==
-            labels_seq.unsqueeze(0))
-        loss_ceiling = loss.new_zeros(1).squeeze() + 1e6
-        # default loss value for a layer where no anchor is positive
-        losses_ = torch.stack([
-            torch.mean(loss[match[:, i]])
-            if match[:, i].sum() > 0 else loss_ceiling for i in labels_seq
-        ])
+        if len(cls_loss.shape) == 2:
+            cls_loss = cls_loss.sum(dim=-1)  # sum loss in class dims
+        loss = cls_loss + reg_loss
+        assert loss.size(0) == assigned_gt_inds.size(0)
+        # default loss value is 1e6 for a layer where no anchor is positive
+        losses_ = loss.new_full(labels_seq.shape, 1e6)
+        for i, l in enumerate(labels_seq):
+            match = assigned_gt_inds == l
+            if match.any():
+                losses_[i] = loss[match].mean()
         return losses_,

     def reassign_loss_single(self, cls_loss, reg_loss, assigned_gt_inds,
@@ -302,7 +303,7 @@ def reassign_loss_single(self, cls_loss, reg_loss, assigned_gt_inds,
             cls_loss: shape (num_anchors, num_classes).
                 Corrected classification loss
             reg_loss: shape (num_anchors). Corrected regression loss
-            keep_indices: shape (num_anchors). Indicating the final
-                positive anchors
+            pos_flags: shape (num_anchors). Indicating the final
+                positive anchors
         """
         loc_weight = torch.ones_like(reg_loss)
         cls_weight = torch.ones_like(cls_loss)
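The rewritten `collect_loss_level_single` above is the selection half of online feature selection: it averages the combined cls+reg loss of every gt on one level, and `loss()` then takes the argmin over levels to pick each gt's best level. A toy sketch of that selection step, with synthetic loss values and the 1e6 placeholder for levels where a gt has no positive anchor:

```python
import torch

def select_levels(level_losses):
    """Pick, per gt, the FPN level whose anchors yield the lowest loss.

    level_losses: (num_levels, num_gts) mean loss of each gt on each level,
                  1e6 where the gt has no positive anchor on that level.
    Returns: (num_gts,) index of the selected level for every gt.
    """
    return level_losses.argmin(dim=0)

# 3 levels, 2 gts; gt 1 has no positive anchor on level 0
losses = torch.tensor([[0.9, 1e6],
                       [0.4, 0.7],
                       [0.6, 0.5]])
min_levels = select_levels(losses)  # tensor([1, 2])
```

The `min_levels` result is exactly what `reassign_loss_single` consumes to demote positives on the non-selected levels.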
From 16e355c1bf259249423fee943e405e4966dd0edf Mon Sep 17 00:00:00 2001
From: wangxinjiang
Date: Thu, 23 Apr 2020 19:58:23 +0800
Subject: [PATCH 08/11] fixed the nan bbox_loss by clamping bbox_pred with min=1e-4

---
 configs/fsaf/fsaf_r50_fpn_1x_coco.py   | 2 +-
 mmdet/models/anchor_heads/fsaf_head.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/configs/fsaf/fsaf_r50_fpn_1x_coco.py b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
index fbc75a45b34..d9196889b2b 100644
--- a/configs/fsaf/fsaf_r50_fpn_1x_coco.py
+++ b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
@@ -49,4 +49,4 @@
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(_delete_=True,
                         grad_clip=dict(max_norm=10, norm_type=2))
-total_epochs = 13
+total_epochs = 12
diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py
index 4a3008ccb36..b886c3a185b 100644
--- a/mmdet/models/anchor_heads/fsaf_head.py
+++ b/mmdet/models/anchor_heads/fsaf_head.py
@@ -181,7 +181,9 @@ def loss(self,
             concat_anchor_list.append(torch.cat(anchor_list[i]))
         all_anchor_list = images_to_levels(concat_anchor_list,
                                            num_level_anchors)
-
+        for i in range(len(bbox_preds)):
+            bbox_preds[i].clamp_(
+                min=1e-4)  # avoid 0 area of the predicted bbox
         losses_cls, losses_bbox = multi_apply(
             self.loss_single,
             cls_scores,
From fece78dd2cd4e8eb0d99a9b2714c00a12f1959b5 Mon Sep 17 00:00:00 2001
From: wangxinjiang
Date: Thu, 23 Apr 2020 21:11:46 +0800
Subject: [PATCH 09/11] fixed the nan bbox_loss by clamping bbox_pred with min=1e-4

---
 mmdet/models/anchor_heads/fsaf_head.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py
index b886c3a185b..803903ad041 100644
--- a/mmdet/models/anchor_heads/fsaf_head.py
+++ b/mmdet/models/anchor_heads/fsaf_head.py
@@ -182,8 +182,8 @@ def loss(self,
         all_anchor_list = images_to_levels(concat_anchor_list,
                                            num_level_anchors)
         for i in range(len(bbox_preds)):
-            bbox_preds[i].clamp_(
-                min=1e-4)  # avoid 0 area of the predicted bbox
+            bbox_preds[i] = bbox_preds[i].clamp(min=1e-4)
+            # avoid 0 area of the predicted bbox
         losses_cls, losses_bbox = multi_apply(
             self.loss_single,
             cls_scores,
From 4238ca261079b7c52eebfe861d1fc3518de5bc64 Mon Sep 17 00:00:00 2001
From: wangxinjiang
Date: Thu, 23 Apr 2020 22:44:12 +0800
Subject: [PATCH 10/11] changed config

---
 configs/fsaf/fsaf_r50_fpn_1x_coco.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/configs/fsaf/fsaf_r50_fpn_1x_coco.py b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
index d9196889b2b..e5698eed108 100644
--- a/configs/fsaf/fsaf_r50_fpn_1x_coco.py
+++ b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
@@ -49,4 +49,3 @@
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(_delete_=True,
                         grad_clip=dict(max_norm=10, norm_type=2))
-total_epochs = 12
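The NaN the two patches above fix arises when every predicted TBLR distance collapses to zero: the decoded box then has zero area and the IoU term divides 0 by 0. A toy one-dimensional illustration of the failure and of the clamp (not the mmdet loss itself):

```python
import torch

def iou_1d(a, b):
    # IoU of the segments [0, a] and [0, b]; if both collapse to zero
    # length the ratio is 0 / 0 and the result is nan
    inter = torch.min(a, b)
    union = a + b - inter
    return inter / union

pred = torch.tensor([0.0])           # degenerate zero-size prediction
print(iou_1d(pred, pred))            # tensor([nan])
safe = pred.clamp(min=1e-4)          # the fix: keep extents positive
print(iou_1d(safe, safe))            # tensor([1.])
```

Patch 09 then swaps the in-place `clamp_` for an out-of-place `clamp`, presumably because modifying the network outputs in place can break autograd's backward pass, which still needs the original values.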
From d127d088257a0d03265e012610a11285c30bbe5e Mon Sep 17 00:00:00 2001
From: wangxinjiang
Date: Thu, 23 Apr 2020 23:42:10 +0800
Subject: [PATCH 11/11] formatter

---
 configs/fsaf/fsaf_r50_fpn_1x_coco.py               |  9 +++------
 ...k_rcnn_r50_caffe_fpn_detectron2-poly_1x_coco.py |  2 +-
 .../mask_rcnn_r50_caffe_fpn_detectron2_1x_coco.py  |  2 +-
 demo/inference_demo.ipynb                          |  2 +-
 mmdet/core/anchor/__init__.py                      |  1 -
 mmdet/core/bbox/__init__.py                        | 10 +++++-----
 .../core/bbox/assigners/effective_area_assigner.py | 14 +++++++-------
 mmdet/core/bbox/coder/__init__.py                  |  5 +++--
 mmdet/core/bbox/coder/tblr_bbox_coder.py           | 14 +++++---------
 .../core/bbox/iou_calculators/iou2d_calculator.py  |  1 -
 mmdet/models/anchor_heads/anchor_head.py           |  2 +-
 mmdet/models/anchor_heads/fsaf_head.py             |  7 +++----
 12 files changed, 30 insertions(+), 39 deletions(-)

diff --git a/configs/fsaf/fsaf_r50_fpn_1x_coco.py b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
index e5698eed108..d4413020cf1 100644
--- a/configs/fsaf/fsaf_r50_fpn_1x_coco.py
+++ b/configs/fsaf/fsaf_r50_fpn_1x_coco.py
@@ -16,10 +16,7 @@
         ratios=[1.0],
         strides=[8, 16, 32, 64, 128],
         center_offset=0.5),
-    bbox_coder=dict(
-        _delete_=True,
-        type='TBLRBBoxCoder',
-        normalizer=1.0),
+    bbox_coder=dict(_delete_=True, type='TBLRBBoxCoder', normalizer=1.0),
     loss_cls=dict(
         type='FocalLoss',
         use_sigmoid=True,
@@ -47,5 +44,5 @@
     pos_weight=-1,
     debug=False)
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
-optimizer_config = dict(_delete_=True,
-                        grad_clip=dict(max_norm=10, norm_type=2))
+optimizer_config = dict(
+    _delete_=True, grad_clip=dict(max_norm=10, norm_type=2))
diff --git a/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2-poly_1x_coco.py b/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2-poly_1x_coco.py
index 93a89d3349d..ed3b34568c8 100644
--- a/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2-poly_1x_coco.py
+++ b/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2-poly_1x_coco.py
@@ -16,7 +16,7 @@
         type='Resize',
         img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
                    (1333, 768), (1333, 800)],
-        multiscale_mode="value",
+        multiscale_mode='value',
         keep_ratio=True),
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2_1x_coco.py b/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2_1x_coco.py
index b0a43463cc8..3e46865ef78 100644
--- a/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2_1x_coco.py
+++ b/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_detectron2_1x_coco.py
@@ -12,7 +12,7 @@
         type='Resize',
         img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
                    (1333, 768), (1333, 800)],
-        multiscale_mode="value",
+        multiscale_mode='value',
         keep_ratio=True),
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb
index fa6d513c5c0..4df4e7c0792 100644
--- a/demo/inference_demo.ipynb
+++ b/demo/inference_demo.ipynb
@@ -98,4 +98,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
\ No newline at end of file
+}
diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py
index 36724c4af2e..d129974d96d 100644
--- a/mmdet/core/anchor/__init__.py
+++ b/mmdet/core/anchor/__init__.py
@@ -1,6 +1,5 @@
 from .anchor_generator import AnchorGenerator, LegacyAnchorGenerator
 from .builder import build_anchor_generator
-
 from .point_generator import PointGenerator
 from .registry import ANCHOR_GENERATORS
 from .utils import anchor_inside_flags, calc_region, images_to_levels
diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py
index c9bbc48aa75..224efb6ea06 100644
--- a/mmdet/core/bbox/__init__.py
+++ b/mmdet/core/bbox/__init__.py
@@ -1,7 +1,7 @@
-from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner, \
-    EffectiveAreaAssigner
-from .coder import BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder, \
-    TBLRBBoxCoder
+from .assigners import (AssignResult, BaseAssigner, EffectiveAreaAssigner,
+                        MaxIoUAssigner)
+from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
+                    TBLRBBoxCoder)
 from .iou_calculators import BboxOverlaps2D, bbox_overlaps
 from .samplers import (BaseSampler, CombinedSampler,
                        InstanceBalancedPosSampler, IoUBalancedNegSampler,
@@ -20,4 +20,4 @@
     'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
     'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
     'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'EffectiveAreaAssigner'
-]
\ No newline at end of file
+]
diff --git a/mmdet/core/bbox/assigners/effective_area_assigner.py b/mmdet/core/bbox/assigners/effective_area_assigner.py
index 7ee495f5608..1e2daa0647b 100644
--- a/mmdet/core/bbox/assigners/effective_area_assigner.py
+++ b/mmdet/core/bbox/assigners/effective_area_assigner.py
@@ -1,9 +1,9 @@
 import torch

-from .assign_result import AssignResult
-from .base_assigner import BaseAssigner
 from ..iou_calculators import build_iou_calculator
 from ..registry import BBOX_ASSIGNERS
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner


 def scale_boxes(bboxes, scale):
@@ -128,8 +128,8 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
         # Only calculate bbox and gt_eff IoF. This enables small prior bboxes
         #   to match large gts
-        bbox_and_gt_eff_overlaps = self.iou_calculator(bboxes, gt_eff,
-                                                       mode='iof')
+        bbox_and_gt_eff_overlaps = self.iou_calculator(
+            bboxes, gt_eff, mode='iof')
         is_bbox_in_gt_eff = is_bbox_in_gt & (
             bbox_and_gt_eff_overlaps > self.min_pos_iof)
         # shape (n, k)
@@ -137,9 +137,9 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
         # constructing ignored gt areas
         gt_ignore = scale_boxes(gt_bboxes, self.neg_area_thr)
-        is_bbox_in_gt_ignore = (self.iou_calculator(bboxes,
-                                                    gt_ignore, mode='iof')
-                                > self.min_pos_iof)
+        is_bbox_in_gt_ignore = (
+            self.iou_calculator(bboxes, gt_ignore, mode='iof') >
+            self.min_pos_iof)
         is_bbox_in_gt_ignore &= (~is_bbox_in_gt_eff)
         # rule out center effective pixels
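The assigner reformatted above implements FSAF's effective-area rule: only anchors whose centers fall inside a scaled-down core of the gt box become positives, while a larger shrunken box marks the ignore region in between. A self-contained sketch of the box-scaling step it relies on (same idea as the file's `scale_boxes`, simplified):

```python
import torch

def scale_boxes(bboxes, scale):
    """Shrink (or grow) xyxy boxes around their centers by `scale`."""
    centers = (bboxes[:, :2] + bboxes[:, 2:]) / 2
    half_wh = (bboxes[:, 2:] - bboxes[:, :2]) / 2 * scale
    return torch.cat((centers - half_wh, centers + half_wh), dim=1)

# a 100x100 gt; with pos_area_thr=0.2 only the central 20x20 core
# produces positives, the 0.2-scaled ignore box marks the rest
gt = torch.tensor([[0., 0., 100., 100.]])
print(scale_boxes(gt, 0.2))  # tensor([[40., 40., 60., 60.]])
```

With the 0.2-0.2 setting used in the best-performing configs, the effective and ignore boxes coincide, so everything outside the core region is a plain negative.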
diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py
index 8a36eabebaa..eb5f966dc6a 100644
--- a/mmdet/core/bbox/coder/__init__.py
+++ b/mmdet/core/bbox/coder/__init__.py
@@ -3,5 +3,6 @@
 from .pseudo_bbox_coder import PseudoBBoxCoder
 from .tblr_bbox_coder import TBLRBBoxCoder

-__all__ = ['BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
-           'TBLRBBoxCoder']
+__all__ = [
+    'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder'
+]
diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py
index 74623f97b7d..9800b76721f 100644
--- a/mmdet/core/bbox/coder/tblr_bbox_coder.py
+++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -1,7 +1,7 @@
 import torch

-from .base_bbox_coder import BaseBBoxCoder
 from ..registry import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder


 @BBOX_CODERS.register_module
@@ -15,12 +15,11 @@ class TBLRBBoxCoder(BaseBBoxCoder):
     .. [1] https://arxiv.org/abs/1903.00621

     Args:
-        normalizer (Sequence[float] | float): denormalizing standard deviation of
-            target for delta coordinates
+        normalizer (Sequence[float] | float): denormalizing standard deviation
+            of target for delta coordinates
     """

-    def __init__(self,
-                 normalizer=1.0):
+    def __init__(self, normalizer=1.0):
         super(BaseBBoxCoder, self).__init__()
         self.normalizer = normalizer

@@ -30,10 +29,7 @@ def encode(self, bboxes, gt_bboxes):
         encoded_bboxes = bboxes2tblr(bboxes, gt_bboxes, self.normalizer)
         return encoded_bboxes

-    def decode(self,
-               bboxes,
-               pred_bboxes,
-               max_shape=None):
+    def decode(self, bboxes, pred_bboxes, max_shape=None):
         assert pred_bboxes.size(0) == bboxes.size(0)
         decoded_bboxes = tblr2bboxes(bboxes, pred_bboxes, self.normalizer,
                                      max_shape)
diff --git a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
index 3b788f9ccb2..51b8b19dfbd 100644
--- a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
+++ b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
@@ -102,4 +102,3 @@
         ious = overlap / (area1[:, None])

     return ious
-
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index 2d47031426d..6bbf0a2dae6 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -335,7 +335,7 @@ def get_targets(self,
         return (labels_list, label_weights_list, bbox_targets_list,
                 bbox_weights_list, num_total_pos, num_total_neg) \
-            + tuple(rest_results)
+                + tuple(rest_results)

     def loss_single(self, cls_score, bbox_pred, anchors, labels,
                     label_weights, bbox_targets, bbox_weights,
                     num_total_samples):
diff --git a/mmdet/models/anchor_heads/fsaf_head.py b/mmdet/models/anchor_heads/fsaf_head.py
index 803903ad041..595c6da65da 100644
--- a/mmdet/models/anchor_heads/fsaf_head.py
+++ b/mmdet/models/anchor_heads/fsaf_head.py
@@ -4,7 +4,6 @@
 from mmdet.core import (anchor_inside_flags, force_fp32, images_to_levels,
                         multi_apply, unmap)
-
 from ..losses.utils import weight_reduce_loss
 from ..registry import HEADS
 from .retina_head import RetinaHead
@@ -100,7 +99,7 @@ def _get_targets_single(self,
                                           self.background_label,
                                           dtype=torch.long)
         label_weights = anchors.new_zeros(num_valid_anchors,
                                           dtype=torch.float)
-        pos_gt_inds = anchors.new_full((num_valid_anchors,),
+        pos_gt_inds = anchors.new_full((num_valid_anchors, ),
                                        -1,
                                        dtype=torch.long)
@@ -250,8 +249,8 @@ def argmax(x):
                                for score, label in zip(scores, labels)])
             return num_correct.float() / num_pos

-    def collect_loss_level_single(self, cls_loss, reg_loss,
-                                  assigned_gt_inds, labels_seq):
+    def collect_loss_level_single(self, cls_loss, reg_loss, assigned_gt_inds,
+                                  labels_seq):
         """Get the average loss in each FPN level w.r.t. each gt label

         Args: