Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSAF on mmdet v2.0 #2520

Closed
wants to merge 12 commits into from
39 changes: 39 additions & 0 deletions configs/fsaf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Feature Selective Anchor-Free Module for Single-Shot Object Detection

FSAF is an anchor-free method published in CVPR2019 ([https://arxiv.org/pdf/1903.00621.pdf](https://arxiv.org/pdf/1903.00621.pdf)).
Actually it is equivalent to the anchor-based method with only one anchor at each feature map position in each FPN level.
And this is how we implemented it.
Only the anchor-free branch is released for its better compatibility with the current framework and less computational budget.

In the original paper, feature maps within the central 0.2-0.5 area of a gt box are tagged as ignored. However,
it is empirically found that a hard threshold (0.2-0.2) gives a further gain on the performance. (see the table below)

## Main Results
### Results on R50/R101/X101-FPN

| Backbone | ignore range | ms-train| Lr schd | Train Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:----------:| :-------: |:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
| R-50 | 0.2-0.5 | N | 1x | 2.97 | 0.43 | 12.3 | 35.7 | fsaf-r50-fpn-1x-20191226-ee860779ad09031f7a58d193b0438a26.pth |
| R-50 | 0.2-0.2 | N | 1x | 2.97 | 0.43 | 13.0 | 37.0 | fsaf-r50-fpn-1x-20191225-d388a744213c3bb187e073f5ccdde5d6.pth |
| R-101 | 0.2-0.5 | N | 1x | 4.87 | 0.58 | 10.6 | 37.8 | fsaf-r101-fpn-1x-20191226-736730f8db59ac0a28262034484ed57d.pth |
| R-101 | 0.2-0.2 | N | 1x | 4.87 | 0.58 | 10.8 | 39.1 | fsaf-r101-fpn-1x-20191225-e1dbbcba40933cd8fc0d0174d1b13aa7.pth |
| X-101 | 0.2-0.2 | N | 1x | 9.02 | 1.23 | 5.6 | 41.8 | fsaf-x101-64x4d-fpn-1x-20191225-82d23b4bc07f2d666eed71fe48de49a9.pth |

**Notes:**
- *1x and 2x mean the model is trained for 12 and 24 epochs, respectively.*
- *All results are obtained with a single model and single-scale test.*
- *X-101 backbone represents ResNext-101-64x4d.*
- *All pretrained backbones use pytorch style.*
- *All models are trained on 8 Titan-XP gpus and tested on a single gpu.*

## Citations
BibTeX reference is as follows.
```
@inproceedings{zhu2019feature,
title={Feature Selective Anchor-Free Module for Single-Shot Object Detection},
author={Zhu, Chenchen and He, Yihui and Savvides, Marios},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={840--849},
year={2019}
}
```
2 changes: 2 additions & 0 deletions configs/fsaf/fsaf_r101_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './fsaf_r50_fpn_1x_coco.py'
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))
48 changes: 48 additions & 0 deletions configs/fsaf/fsaf_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
_base_ = '../retinanet/retinanet_r50_fpn_1x_coco.py'
# model settings
model = dict(
type='FSAF',
bbox_head=dict(
type='FSAFHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
reg_decoded_bbox=True,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=1,
scales_per_octave=1,
ratios=[1.0],
strides=[8, 16, 32, 64, 128],
center_offset=0.5),
bbox_coder=dict(_delete_=True, type='TBLRBBoxCoder', normalizer=1.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0,
reduction='none'),
loss_bbox=dict(
_delete_=True,
type='IoULoss',
eps=1e-6,
loss_weight=1.0,
reduction='none'),
))

# training and testing settings
train_cfg = dict(
assigner=dict(
_delete_=True,
type='EffectiveAreaAssigner',
pos_area_thr=0.2,
neg_area_thr=0.2,
min_pos_iof=0.01),
allowed_border=-1,
pos_weight=-1,
debug=False)
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=10, norm_type=2))
13 changes: 13 additions & 0 deletions configs/fsaf/fsaf_x101_64x4d_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = './fsaf_r50_fpn_1x_coco.py'
model = dict(
pretrained='open-mmlab://resnext101_64x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=64,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
style='pytorch'))
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
type='Resize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
multiscale_mode="value",
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
type='Resize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
multiscale_mode="value",
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
Expand Down
2 changes: 1 addition & 1 deletion demo/inference_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
8 changes: 5 additions & 3 deletions mmdet/core/bbox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner
from .coder import BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder
from .assigners import (AssignResult, BaseAssigner, EffectiveAreaAssigner,
MaxIoUAssigner)
from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
TBLRBBoxCoder)
from .iou_calculators import BboxOverlaps2D, bbox_overlaps
from .samplers import (BaseSampler, CombinedSampler,
InstanceBalancedPosSampler, IoUBalancedNegSampler,
Expand All @@ -17,5 +19,5 @@
'SamplingResult', 'build_assigner', 'build_sampler', 'bbox_flip',
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
'DeltaXYWHBBoxCoder'
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'EffectiveAreaAssigner'
]
3 changes: 2 additions & 1 deletion mmdet/core/bbox/assigners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from .assign_result import AssignResult
from .atss_assigner import ATSSAssigner
from .base_assigner import BaseAssigner
from .effective_area_assigner import EffectiveAreaAssigner
from .max_iou_assigner import MaxIoUAssigner
from .point_assigner import PointAssigner

__all__ = [
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
'PointAssigner', 'ATSSAssigner'
'PointAssigner', 'ATSSAssigner', 'EffectiveAreaAssigner'
]
237 changes: 237 additions & 0 deletions mmdet/core/bbox/assigners/effective_area_assigner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import torch

from ..iou_calculators import build_iou_calculator
from ..registry import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


def scale_boxes(bboxes, scale):
"""Expand an array of boxes by a given scale.
Args:
bboxes (Tensor): shape (m, 4)
scale (float): the scale factor of bboxes

Returns:
(Tensor): shape (m, 4) scaled bboxes
"""
w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5

w_half *= scale
h_half *= scale

boxes_exp = torch.zeros_like(bboxes)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp


def is_located_in(points, bboxes, is_aligned=False):
""" is center a locates in box b
Then we compute the area of intersect between box_a and box_b.
Args:
points: (tensor) bounding boxes, Shape: [m,2].
bboxes: (tensor) bounding boxes, Shape: [n,4].
If is_aligned is ``True``, then m mush be equal to n
Return:
(tensor) intersection area, Shape: [m, n]. If is_aligned ``True``,
then shape = [m]
"""
if not is_aligned:
return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
(points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
(points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
(points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
else:
return (points[:, 0] > bboxes[:, 0]) & \
(points[:, 0] < bboxes[:, 2]) & \
(points[:, 1] > bboxes[:, 1]) & \
(points[:, 1] < bboxes[:, 3])


def bboxes_area(bboxes):
"""Compute the area of an array of boxes."""
w = (bboxes[:, 2] - bboxes[:, 0])
h = (bboxes[:, 3] - bboxes[:, 1])
areas = w * h

return areas


@BBOX_ASSIGNERS.register_module
class EffectiveAreaAssigner(BaseAssigner):
"""Assign a corresponding gt bbox or background to each bbox.

Each proposals will be assigned with `-1`, `0`, or a positive integer
indicating the ground truth index.

- -1: don't care
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt

Args:
pos_area_thr (float): threshold within which pixels are
labelled as positive.
neg_area_thr (float): threshold above which pixels are
labelled as positive.
min_pos_iof (float): minimum iof of a pixel with a gt to be
labelled as positive
ignore_gt_area_thr (float): threshold within which the pixels
are ignored when the gt is labelled as ignored
"""

def __init__(self,
pos_area_thr,
neg_area_thr,
min_pos_iof=1e-2,
ignore_gt_area_thr=0.5,
iou_calculator=dict(type='BboxOverlaps2D')):
self.pos_area_thr = pos_area_thr
self.neg_area_thr = neg_area_thr
self.min_pos_iof = min_pos_iof
self.ignore_gt_area_thr = ignore_gt_area_thr
self.iou_calculator = build_iou_calculator(iou_calculator)

def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
"""Assign gt to bboxes.

This method assign a gt bbox to every bbox (proposal/anchor), each bbox
will be assigned with -1, 0, or a positive number. -1 means don't care,
0 means negative sample, positive number is the index (1-based) of
assigned gt.

Args:
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
gt_labels (Tensor, optional): Label of gt_bboxes, shape (num_gt, ).

Returns:
:obj:`AssignResult`: The assign result.
"""
if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or bboxes')
bboxes = bboxes[:, :4]

# constructing effective gt areas
gt_eff = scale_boxes(gt_bboxes, self.pos_area_thr)
# effective bboxes, i.e. center 0.2 part
bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
# the center points lie within the gt boxes

# Only calculate bbox and gt_eff IoF. This enables small prior bboxes
# to match large gts
bbox_and_gt_eff_overlaps = self.iou_calculator(
bboxes, gt_eff, mode='iof')
is_bbox_in_gt_eff = is_bbox_in_gt & (
bbox_and_gt_eff_overlaps > self.min_pos_iof)
# shape (n, k)
# the center point of effective priors should be within the gt box

# constructing ignored gt areas
gt_ignore = scale_boxes(gt_bboxes, self.neg_area_thr)
is_bbox_in_gt_ignore = (
self.iou_calculator(bboxes, gt_ignore, mode='iof') >
self.min_pos_iof)
is_bbox_in_gt_ignore &= (~is_bbox_in_gt_eff)
# rule out center effective pixels

gt_areas = bboxes_area(gt_bboxes)
_, sort_idx = gt_areas.sort(descending=True)
# rank all gt bbox areas so that smaller instances
# can overlay larger ones

assigned_gt_inds = self.assign_one_hot_gt_indices(
is_bbox_in_gt_eff, is_bbox_in_gt_ignore, gt_priority=sort_idx)

# ignored gts
if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
gt_bboxes_ignore = scale_boxes(
gt_bboxes_ignore, scale=self.ignore_gt_area_thr)
is_bbox_in_ignored_gts = is_located_in(bbox_centers,
gt_bboxes_ignore)
is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
assigned_gt_inds[is_bbox_in_ignored_gts] = -1

num_bboxes, num_gts = is_bbox_in_gt_eff.shape
if gt_labels is not None:
assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
if pos_inds.numel() > 0:
assigned_labels[pos_inds] = gt_labels[
assigned_gt_inds[pos_inds] - 1]
else:
assigned_labels = None

return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)

def assign_one_hot_gt_indices(self,
is_bbox_in_gt_eff,
is_bbox_in_gt_ignore,
gt_priority=None):
"""Assign only one gt index to each prior box
(smaller gt has higher priority)

Args:
is_bbox_in_gt_eff: shape [num_prior, num_gt].
bool tensor indicating the bbox center is in
the effective area of a gt (e.g. 0-0.2)
is_bbox_in_gt_ignore: shape [num_prior, num_gt].
bool tensor indicating the bbox
center is in the ignored area of a gt (e.g. 0.2-0.5)
gt_labels: shape [num_gt]. gt labels (0-80 for COCO)
gt_priority: shape [num_gt]. gt priorities.
The gt with a higher priority is more likely to be
assigned to the bbox when the bbox match with multiple gts

Returns:
:obj:`AssignResult`: The assign result.
"""
num_bboxes, num_gts = is_bbox_in_gt_eff.shape

if gt_priority is None:
gt_priority = torch.arange(num_gts).to(is_bbox_in_gt_eff.device)
# the bigger, the more preferable to be assigned
# the assigned inds are by default 0 (background)
assigned_gt_inds = is_bbox_in_gt_eff.new_full((num_bboxes, ),
0,
dtype=torch.long)
inds_of_match = torch.any(is_bbox_in_gt_eff, dim=1)
# matched bboxes (to any gt)
inds_of_ignore = torch.any(is_bbox_in_gt_ignore, dim=1)
# ignored indices
assigned_gt_inds[inds_of_ignore] = -1
if is_bbox_in_gt_eff.sum() == 0: # No gt match
return assigned_gt_inds

# The priority of each prior box and gt pair. If one prior box is
# matched bo multiple gts. Only the pair with the highest priority
# is saved
pair_priority = is_bbox_in_gt_eff.new_full((num_bboxes, num_gts),
-1,
dtype=torch.long)

# Each bbox could match with multiple gts.
# The following codes deal with this situation

# Whether a bbox match a gt, bool tensor, shape [num_match, num_gt]
matched_bbox_and_gt_correspondence = is_bbox_in_gt_eff[inds_of_match]
# The matched gt index of each positive bbox. Length >= num_match,
# since one bbox could match multiple gts
matched_bbox_gt_inds =\
torch.nonzero(matched_bbox_and_gt_correspondence)[:, 1]
# Assign priority to each bbox-gt pair.
pair_priority[is_bbox_in_gt_eff] = gt_priority[matched_bbox_gt_inds]
_, argmax_priority = pair_priority[inds_of_match].max(dim=1)
# the maximum shape [num_match]
# effective indices
assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based
return assigned_gt_inds
Loading