Code for Panoptic FPN
* add configs for Panoptic FPN

* modify ./mmdet/core/evaluation/coco_utils.py for saving semantic segmentation results

* modify ./mmdet/datasets/custom.py for reading semantic segmentation annotations

* modify ./mmdet/models/detectors/two_stage.py for Panoptic FPN training

* add ./mmdet/models/mask_heads/semantic_segm_head.py for Panoptic FPN training

* modify ./mmdet/models/detectors/test_mixins.py for Panoptic FPN testing

* modify ./tools/test.py for Panoptic segmentation evaluation

* modify ./mmdet/models/losses/cross_entropy_loss.py for supporting training with ignore label
GT9505 committed Jul 10, 2019
1 parent 084a389 commit f9da4f3
Showing 14 changed files with 576 additions and 12 deletions.
40 changes: 40 additions & 0 deletions configs/panopticFPN/README.md
@@ -0,0 +1,40 @@
# Panoptic FPN

## Introduction

```
@inproceedings{Kirillov_2019_CVPR,
  title={Panoptic Feature Pyramid Networks},
  author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Dollar, Piotr},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2019}
}
```

## Preparation

a. Download the 2017 [Panoptic Train/Val annotations](http://cocodataset.org/#download) for COCO.

b. Install the [COCO 2018 Panoptic Segmentation Task API](https://github.com/cocodataset/panopticapi).

c. Symlink the panoptic segmentation annotation root to `$MMDETECTION/data/coco/PanopticSegm_annotations`.

Note: the panoptic segmentation annotations should end up under `$MMDETECTION/data/coco/PanopticSegm_annotations/PanopticSegm_annotations/`.

d. Extract the semantic segmentation annotations from the COCO panoptic format with the script `$PANOPTICAPI/converters/panoptic2semantic_segmentation.py`.

Note: the semantic segmentation annotations should end up under `$MMDETECTION/data/coco/PanopticSegm_annotations/SemanticSegm_annotations/semantic_val2017` and `$MMDETECTION/data/coco/PanopticSegm_annotations/SemanticSegm_annotations/semantic_train2017`; a quick path check is sketched below.
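
The expected layout can be sanity-checked before training. Below is a minimal Python sketch, assuming the directory names from the notes above (the check itself is illustrative and not part of the codebase):

```python
import os.path as osp

# Expected annotation layout under $MMDETECTION/data/coco (illustrative check only).
coco_root = 'data/coco'
expected = [
    'PanopticSegm_annotations/PanopticSegm_annotations/panoptic_train2017.json',
    'PanopticSegm_annotations/PanopticSegm_annotations/panoptic_val2017.json',
    'PanopticSegm_annotations/SemanticSegm_annotations/semantic_train2017',
    'PanopticSegm_annotations/SemanticSegm_annotations/semantic_val2017',
]
for rel_path in expected:
    path = osp.join(coco_root, rel_path)
    print('{:4s} {}'.format('OK' if osp.exists(path) else 'MISS', path))
```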

## Evaluation

a. Use the script `$PANOPTICAPI/combine_semantic_and_instance_predictions.py` to combine the semantic and instance prediction JSON files into panoptic segmentation results (the combination idea is sketched after these steps).

b. Use the script `$PANOPTICAPI/evaluation.py` to evaluate the panoptic segmentation results.
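
Conceptually, the combination step pastes instance masks over the stuff prediction in descending score order and keeps only large enough stuff regions. The following is a hedged toy sketch of that idea, not the panopticapi implementation (the function name, thresholds, and inputs are made up for illustration):

```python
import numpy as np

def combine_panoptic_sketch(sem_seg, inst_masks, inst_labels, inst_scores,
                            overlap_thr=0.5, stuff_area_limit=4096):
    """Toy combination of semantic + instance predictions into one panoptic map.

    sem_seg: (H, W) int array of predicted stuff/thing class ids (0 = void).
    inst_masks / inst_labels / inst_scores: per-instance boolean masks,
    class ids and confidence scores.
    """
    pan = np.zeros(sem_seg.shape, dtype=np.int32)
    segments, next_id = [], 1
    # 1) paste instances in descending score order, skipping heavily occluded ones
    for i in np.argsort(inst_scores)[::-1]:
        visible = inst_masks[i] & (pan == 0)
        if visible.sum() < overlap_thr * max(inst_masks[i].sum(), 1):
            continue
        pan[visible] = next_id
        segments.append(dict(id=next_id, category_id=int(inst_labels[i]), isthing=True))
        next_id += 1
    # 2) fill the remaining pixels with stuff classes covering a large enough area
    for cls in np.unique(sem_seg):
        region = (sem_seg == cls) & (pan == 0)
        if cls == 0 or region.sum() < stuff_area_limit:
            continue
        pan[region] = next_id
        segments.append(dict(id=next_id, category_id=int(cls), isthing=False))
        next_id += 1
    return pan, segments
```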

## Results and Models

| Backbone | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | PQ   |
|:--------:|:-------:|:--------:|:-------------------:|:--------------:|:----:|
| R-50     | 1x      | 9.6      | 0.831               | 5.8            | 38.8 |

The model is trained and tested on 8 Titan Xp GPUs.
205 changes: 205 additions & 0 deletions configs/panopticFPN/panopticFPN_r50_1x.py
@@ -0,0 +1,205 @@
# model settings
model = dict(
    type='PanopticFPN',
    pretrained='modelzoo://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    semantic_segm_head=dict(
        type='SemanticSegmHead',
        in_channels=256,
        out_channels=128,
        start_level=0,
        end_level=3,
        num_things_classes=80,
        num_classes=134,
        ignore_label=200,
        loss_semantic_segm=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.5),
        norm_cfg=dict(type='GN', num_groups=8, requires_grad=True)),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
    mask_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    mask_head=dict(
        type='FCNMaskHead',
        num_convs=4,
        in_channels=256,
        conv_out_channels=256,
        num_classes=81,
        loss_mask=dict(
            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        mask_size=28,
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_thr=0.5),
        max_per_img=100,
        mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=True,
        with_crowd=True,
        with_label=True,
        with_semantic_seg=True,
        semantic_labels_map=data_root +
        'PanopticSegm_annotations/PanopticSegm_annotations/' +
        'panoptic_val2017.json',
        seg_prefix=data_root +
        'PanopticSegm_annotations/SemanticSegm_annotations/' +
        'semantic_train2017',
    ),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=True,
        with_crowd=True,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        with_semantic_seg=False,
        semantic_labels_map=data_root +
        'PanopticSegm_annotations/PanopticSegm_annotations/' +
        'panoptic_val2017.json',
        test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/panopticFPN_r50_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
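
The semantic head above is configured with `ignore_label=200`, and the commit message notes that `cross_entropy_loss.py` was extended to support an ignore label (that diff is not expanded in this view). A minimal sketch of the usual mechanism, using PyTorch's built-in `ignore_index` rather than the actual mmdet loss code:

```python
import torch
import torch.nn.functional as F

# Toy semantic logits: batch of 1, 134 classes, 4x4 feature map.
logits = torch.randn(1, 134, 4, 4)
target = torch.randint(0, 134, (1, 4, 4))
target[0, 0, 0] = 200  # pixel carrying the ignore value

# Locations whose target equals ignore_index contribute neither loss nor gradient,
# which is how an ignore_label such as 200 is typically honoured.
loss = F.cross_entropy(logits, target, ignore_index=200)
print(loss.item())
```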
40 changes: 38 additions & 2 deletions mmdet/core/evaluation/coco_utils.py
@@ -154,20 +154,56 @@ def segm2json(dataset, results):
    return bbox_json_results, segm_json_results


def results2json(dataset, results, out_file):
def SemanticSegm2json(dataset, results):
    semantic_segm_json_results = []
    for idx in range(len(dataset)):
        img_id = dataset.img_ids[idx]
        semantic_segm = results[idx]
        for label in range(len(semantic_segm)):
            # skip classes with no semantic segm result
            if semantic_segm[label] == []:
                continue
            semantic_segm_cls = semantic_segm[label][0]
            data = dict()
            data['image_id'] = img_id
            data['category_id'] = dataset.label2semantic[label + 1]
            semantic_segm_cls['counts'] = \
                semantic_segm_cls['counts'].decode()
            data['segmentation'] = semantic_segm_cls
            semantic_segm_json_results.append(data)
    return semantic_segm_json_results


def results2json(dataset, results, out_file, eval_types):
    result_files = dict()
    if isinstance(results[0], list):
        json_results = det2json(dataset, results)
        result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
        result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
        mmcv.dump(json_results, result_files['bbox'])
    elif isinstance(results[0], tuple):
        json_results = segm2json(dataset, results)
        if 'semantic_segm' in eval_types:
            instance_segm_results = []
            semantic_segm_results = []
            for idx in range(len(dataset)):
                det, seg, semantic_segm = results[idx]
                instance_segm_results.append([det, seg])
                semantic_segm_results.append(semantic_segm)
        else:
            instance_segm_results = results
        json_results = segm2json(dataset, instance_segm_results)
        result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
        result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
        result_files['segm'] = '{}.{}.json'.format(out_file, 'segm')
        mmcv.dump(json_results[0], result_files['bbox'])
        mmcv.dump(json_results[1], result_files['segm'])
        if 'semantic_segm' in eval_types:
            semantic_segm_json_results = SemanticSegm2json(
                dataset, semantic_segm_results)
            result_files['semantic_segm'] = '{}.{}.json'.format(
                out_file, 'semantic_segm')
            mmcv.dump(semantic_segm_json_results,
                      result_files['semantic_segm'])
    elif isinstance(results[0], np.ndarray):
        json_results = proposal2json(dataset, results)
        result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal')
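
`SemanticSegm2json` above decodes the RLE `counts` field before dumping because pycocotools encodes masks with `counts` as `bytes`, which are not JSON-serializable. A small standalone sketch of that round trip (assuming pycocotools is installed):

```python
import json

import numpy as np
from pycocotools import mask as mask_util

# Encode a toy binary mask to COCO RLE; pycocotools returns 'counts' as bytes.
binary_mask = np.zeros((4, 4), dtype=np.uint8)
binary_mask[1:3, 1:3] = 1
rle = mask_util.encode(np.asfortranarray(binary_mask))
print(type(rle['counts']))  # <class 'bytes'>

# Decode to str so the result can be written out with json / mmcv.dump.
rle['counts'] = rle['counts'].decode()
print(json.dumps({'segmentation': rle}))
```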
24 changes: 24 additions & 0 deletions mmdet/datasets/custom.py
@@ -51,6 +51,7 @@ def __init__(self,
                 with_crowd=True,
                 with_label=True,
                 with_semantic_seg=False,
                 semantic_labels_map=None,
                 seg_prefix=None,
                 seg_scale_factor=1,
                 extra_aug=None,
@@ -101,6 +102,21 @@ def __init__(self,
        self.with_label = with_label
        # with semantic segmentation (stuff) annotation or not
        self.with_seg = with_semantic_seg

        self.semantic_labels_map = semantic_labels_map
        # mappings between panoptic category ids and contiguous labels
        if self.semantic_labels_map:
            import json
            panoptic_json_file = json.load(open(semantic_labels_map, 'rb'))
            self.semantic2label = {
                cat_id['id']: i + 1
                for i, cat_id in enumerate(panoptic_json_file['categories'])
            }
            self.label2semantic = {
                i + 1: cat_id['id']
                for i, cat_id in enumerate(panoptic_json_file['categories'])
            }

        # prefix of semantic segmentation map path
        self.seg_prefix = seg_prefix
        # rescale factor for segmentation maps
@@ -226,6 +242,14 @@ def prepare_train_img(self, idx):
            gt_seg = self.seg_transform(gt_seg.squeeze(), img_scale, flip)
            gt_seg = mmcv.imrescale(
                gt_seg, self.seg_scale_factor, interpolation='nearest')

            # map semantic ids to contiguous labels
            if self.semantic_labels_map:
                gt_seg_unique = np.unique(gt_seg)
                for i in gt_seg_unique:
                    if i != 0:
                        gt_seg[gt_seg == i] = self.semantic2label[i]

            gt_seg = gt_seg[None, ...]
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape, scale_factor,
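
The `semantic2label` / `label2semantic` dictionaries built above remap the sparse COCO panoptic category ids to contiguous training labels starting at 1 (label 0 stays reserved for unlabelled pixels, which `prepare_train_img` skips) and back again when dumping results. A toy illustration with a made-up `categories` list:

```python
# Toy stand-in for the 'categories' entry of panoptic_val2017.json (ids illustrative).
categories = [
    {'id': 1, 'name': 'person', 'isthing': 1},
    {'id': 92, 'name': 'banner', 'isthing': 0},
    {'id': 193, 'name': 'water-other', 'isthing': 0},
]

# Sparse panoptic ids -> contiguous labels 1..N, and the inverse mapping.
semantic2label = {cat['id']: i + 1 for i, cat in enumerate(categories)}
label2semantic = {i + 1: cat['id'] for i, cat in enumerate(categories)}

print(semantic2label)   # {1: 1, 92: 2, 193: 3}
print(label2semantic)   # {1: 1, 2: 92, 3: 193}
```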
4 changes: 2 additions & 2 deletions mmdet/models/builder.py
@@ -1,8 +1,8 @@
from torch import nn

from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
                       LOSSES, DETECTORS)
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS,
                       SHARED_HEADS, HEADS, LOSSES, DETECTORS)


def build(cfg, registry, default_args=None):
3 changes: 2 additions & 1 deletion mmdet/models/detectors/__init__.py
@@ -11,9 +11,10 @@
from .fcos import FCOS
from .grid_rcnn import GridRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .panoptic_fpn import PanopticFPN

__all__ = [
    'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
    'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
    'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN'
    'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'PanopticFPN'
]
5 changes: 5 additions & 0 deletions mmdet/models/detectors/base.py
@@ -22,6 +22,11 @@ def __init__(self):
    def with_neck(self):
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_semantic_segm_head(self):
        return hasattr(self, 'semantic_segm_head') and \
            self.semantic_segm_head is not None

    @property
    def with_shared_head(self):
        return hasattr(self, 'shared_head') and self.shared_head is not None
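
The `two_stage.py` changes that consume this property are not expanded in this view; the following is only a hedged sketch of how a `with_semantic_segm_head` check is typically used to gate a semantic branch during training, with an assumed `loss()` interface on the head, not the actual PanopticFPN code:

```python
import torch.nn as nn

class TwoStageWithSemanticSketch(nn.Module):
    """Illustrative only: gating an optional semantic head the way mmdet
    detectors gate other optional modules."""

    def __init__(self, semantic_segm_head=None):
        super().__init__()
        self.semantic_segm_head = semantic_segm_head

    @property
    def with_semantic_segm_head(self):
        return hasattr(self, 'semantic_segm_head') and \
            self.semantic_segm_head is not None

    def forward_train(self, feats, gt_semantic_seg):
        losses = {}
        # ... RPN / RoI losses of the two-stage detector would go here ...
        if self.with_semantic_segm_head:
            semantic_pred = self.semantic_segm_head(feats)
            # assumed interface: the head exposes loss(pred, target)
            losses['loss_semantic_segm'] = self.semantic_segm_head.loss(
                semantic_pred, gt_semantic_seg)
        return losses
```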