[Feature] Support ISA module #70

Merged
merged 14 commits into from
Sep 9, 2021
45 changes: 45 additions & 0 deletions configs/_base_/models/isanet_r50-d8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='ISAHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        isa_channels=256,
        down_factor=(8, 8),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
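The channel counts in this base config follow from the backbone settings: `in_index=3` feeds the decode head 2048 channels and `in_index=2` feeds the auxiliary head 1024 channels, and the `strides=(1, 2, 1, 1)` setting is what makes this a "d8" (output stride 8) model. The sketch below reproduces that arithmetic. It assumes the usual ResNet-50 layout (a stem with total stride 4, 64 base channels, bottleneck expansion 4), none of which is stated in the config itself; `stage_layout` is a hypothetical helper, not part of mmseg.

```python
def stage_layout(strides, stem_stride=4, base_channels=64, expansion=4):
    """Return (output_stride, out_channels) for each backbone stage.

    Assumes a ResNet-style bottleneck backbone: the stem downsamples by
    `stem_stride`, and stage i outputs base_channels * 2**i * expansion
    channels.
    """
    layout = []
    total_stride = stem_stride
    for i, stride in enumerate(strides):
        total_stride *= stride
        channels = base_channels * (2 ** i) * expansion
        layout.append((total_stride, channels))
    return layout


layout = stage_layout((1, 2, 1, 1))
for index, (stride, channels) in enumerate(layout):
    print(f'out_index={index}: output stride {stride}, {channels} channels')
```

Under these assumptions a 512x1024 crop reaches the heads as a 64x128 feature map.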
56 changes: 56 additions & 0 deletions configs/isanet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Interlaced Sparse Self-Attention for Semantic Segmentation

## Introduction

<!-- [ALGORITHM] -->

```
@article{huang2019isa,
title={Interlaced Sparse Self-Attention for Semantic Segmentation},
author={Huang, Lang and Yuan, Yuhui and Guo, Jianyuan and Zhang, Chao and Chen, Xilin and Wang, Jingdong},
journal={arXiv preprint arXiv:1907.12273},
year={2019}
}

The technical report above is also presented at:
@article{yuan2021ocnet,
title={OCNet: Object Context for Semantic Segmentation},
author={Yuan, Yuhui and Huang, Lang and Guo, Jianyuan and Zhang, Chao and Chen, Xilin and Wang, Jingdong},
journal={International Journal of Computer Vision},
pages={1--24},
year={2021},
publisher={Springer}
}
```
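
ISA factorizes dense self-attention into a long-range (global) step and a short-range (local) step, so each affinity matrix stays small. The sketch below counts affinity-matrix entries for both schemes. The 64x128 feature map (a 512x1024 crop at output stride 8) and the group size of 8x8 are assumptions chosen to match the Cityscapes configs; both function names are hypothetical.

```python
def dense_affinity_entries(h, w):
    """Entries in one dense affinity matrix over all h * w positions."""
    n = h * w
    return n * n


def interlaced_affinity_entries(h, w, loc_h, loc_w):
    """Entries across all global-step and local-step affinity matrices.

    Assumes h and w are divisible by loc_h and loc_w respectively.
    """
    glb_h, glb_w = h // loc_h, w // loc_w
    # Global step: loc_h * loc_w groups, each with glb_h * glb_w positions.
    global_entries = (loc_h * loc_w) * (glb_h * glb_w) ** 2
    # Local step: glb_h * glb_w groups, each with loc_h * loc_w positions.
    local_entries = (glb_h * glb_w) * (loc_h * loc_w) ** 2
    return global_entries + local_entries


dense = dense_affinity_entries(64, 128)
interlaced = interlaced_affinity_entries(64, 128, 8, 8)
print(dense, interlaced, dense / interlaced)
```

This counts only affinity entries; it says nothing about measured memory or speed, which the tables below leave blank for ISANet.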

## Results and models

### Cityscapes

| Method   | Backbone | Crop Size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) |  mIoU | mIoU(ms+flip) | download |
|----------|----------|-----------|------------|--------:|----------|----------------|------:|--------------:|----------|
| ISANet   | R-101-D8 | 512x1024  | 8          |   40000 |          |                | 79.32 |               | [model](https://drive.google.com/file/d/1oWAcDwj_ILRvwWp-bcGKJ7g1QlrT7Myo/view?usp=sharing)/[log](https://drive.google.com/file/d/1oWPEtE16FYF4P4LMl1uf_qiElPaOJ0-y/view?usp=sharing) |
| ISANet   | R-101-D8 | 512x1024  | 16         |   40000 |          |                | 79.56 |               |          |
| ISANet   | R-101-D8 | 512x1024  | 8          |   80000 |          |                | 79.67 |               |          |
| ISANet   | R-101-D8 | 512x1024  | 16         |   80000 |          |                | 80.18 |               |          |
| NonLocal | R-101-D8 | 512x1024  | 8          |   40000 | 10.9     | 1.95           | 78.66 |               |          |
| NonLocal | R-101-D8 | 512x1024  | 8          |   80000 | -        | -              | 78.93 |               |          |

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| ISANet | R-50-D8 | 512x512 | 80000 | | | | | |
| ISANet | R-101-D8 | 512x512 | 80000 | | | | | |
| ISANet | R-50-D8 | 512x512 | 160000 | | | | | |
| ISANet | R-101-D8 | 512x512 | 160000 | | | 43.77| | |
| NonLocal | R-101-D8 | 512x512 | 160000 | - | - | 43.36 | | |

### Pascal VOC 2012 + Aug

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| ISANet | R-50-D8 | 512x512 | 20000 | | | | | |
| ISANet | R-101-D8 | 512x512 | 20000 | | | | | |
| ISANet | R-50-D8 | 512x512 | 40000 | | | | | |
| ISANet | R-101-D8 | 512x512 | 40000 | | | | | |
8 changes: 8 additions & 0 deletions configs/isanet/isanet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Collections:
- Metadata:
    Training Data:
    - Cityscapes
    - ADE20K
    - Pascal VOC 2012 + Aug
  Name: isanet
Models: []
5 changes: 5 additions & 0 deletions configs/isanet/isanet_r101-d8_512x1024_40k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
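The derived configs in this PR only state what differs from `_base_`: here, the pretrained checkpoint and the backbone depth. A minimal sketch of that override behaviour, assuming a plain recursive dict merge, is shown below. `merge_config` is a hypothetical helper written for illustration; the real config loader in mmcv has more features that are not modelled here, and the base dict is trimmed to a few keys.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value in `override`
    replaces the corresponding value in `base`.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Trimmed copy of the model settings in _base_/models/isanet_r50-d8.py.
base_model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(type='ResNetV1c', depth=50, num_stages=4),
    decode_head=dict(type='ISAHead', in_channels=2048, num_classes=19))

# The override from isanet_r101-d8_512x1024_40k_cityscapes.py.
r101_override = dict(
    pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

r101_model = merge_config(base_model, r101_override)
print(r101_model['pretrained'])
print(r101_model['backbone'])
```

Only `depth` changes inside `backbone`; sibling keys such as `type` and `num_stages` are inherited from the base.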
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_512x512_160k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_512x512_160k_ade20k.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_512x512_20k_voc12aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_512x512_20k_voc12aug.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_512x512_40k_voc12aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_512x512_40k_voc12aug.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_512x512_80k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_512x512_80k_ade20k.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_769x769_40k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_769x769_40k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/isanet/isanet_r101-d8_769x769_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './isanet_r50-d8_769x769_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
4 changes: 4 additions & 0 deletions configs/isanet/isanet_r50-d8_512x1024_40k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
4 changes: 4 additions & 0 deletions configs/isanet/isanet_r50-d8_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
7 changes: 7 additions & 0 deletions configs/isanet/isanet_r50-d8_512x512_160k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(
    decode_head=dict(num_classes=150),
    auxiliary_head=dict(num_classes=150),
    test_cfg=dict(mode='whole'))
7 changes: 7 additions & 0 deletions configs/isanet/isanet_r50-d8_512x512_20k_voc12aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py',
    '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_20k.py'
]
model = dict(
    decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21))
7 changes: 7 additions & 0 deletions configs/isanet/isanet_r50-d8_512x512_40k_voc12aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py',
    '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
model = dict(
    decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21))
7 changes: 7 additions & 0 deletions configs/isanet/isanet_r50-d8_512x512_80k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(num_classes=150),
    auxiliary_head=dict(num_classes=150),
    test_cfg=dict(mode='whole'))
9 changes: 9 additions & 0 deletions configs/isanet/isanet_r50-d8_769x769_40k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True),
    test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513)))
9 changes: 9 additions & 0 deletions configs/isanet/isanet_r50-d8_769x769_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/isanet_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True),
    test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513)))
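The 769x769 configs test in `slide` mode with a 769x769 crop and a 513x513 stride. The sketch below shows one common way such settings tile an image into overlapping windows. The 1025x2049 image size is an assumption made for the example, and `slide_windows` is a hypothetical helper that illustrates the tiling arithmetic rather than quoting the library's inference code.

```python
def slide_windows(img_size, crop_size, stride):
    """Return (y1, y2, x1, x2) crops that together cover the image.

    Windows step by `stride`; the last window on each axis is shifted
    back so that it ends exactly at the image border.
    """
    h_img, w_img = img_size
    h_crop, w_crop = crop_size
    h_stride, w_stride = stride
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    windows = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            windows.append((y1, y2, x1, x2))
    return windows


# Assumed full-resolution test image of height 1025 and width 2049.
windows = slide_windows((1025, 2049), (769, 769), (513, 513))
for window in windows:
    print(window)
```

Predictions in overlapping regions would then need to be combined, for example by averaging; that step is not shown here.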
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .isa_head import ISAHead
from .lraspp_head import LRASPPHead
from .nl_head import NLHead
from .ocr_head import OCRHead
@@ -29,5 +30,5 @@
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'SegformerHead'
'SETRMLAHead', 'SegformerHead', 'ISAHead'
]
144 changes: 144 additions & 0 deletions mmseg/models/decode_heads/isa_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import math

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-Attention Module.

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict | None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.output_project = self.build_project(
            in_channels,
            in_channels,
            num_convs=1,
            use_conv_module=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):
        """Forward function."""
        context = super(SelfAttentionBlock, self).forward(x, x)
        return self.output_project(context)


@HEADS.register_module()
class ISAHead(BaseDecodeHead):
    """Interlaced Sparse Self-Attention for Semantic Segmentation.

    This head is the implementation of `ISA
    <https://arxiv.org/abs/1907.12273>`_.

    Args:
        isa_channels (int): The channels of ISA Module.
        down_factor (tuple[int]): The local group size of ISA.
    """

    def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
        super(ISAHead, self).__init__(**kwargs)
        self.down_factor = down_factor

        self.in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.global_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.local_relation = SelfAttentionBlock(
            self.channels,
            isa_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.out_conv = ConvModule(
            self.channels * 2,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x_ = self._transform_inputs(inputs)
        x = self.in_conv(x_)
        residual = x

        n, c, h, w = x.size()
        loc_h, loc_w = self.down_factor  # size of local group in H- and W-axes
        glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
        pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
        if pad_h > 0 or pad_w > 0:  # pad if the size is not divisible
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                       pad_h - pad_h // 2)
            x = F.pad(x, padding)

        # global relation
        x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
        # do permutation to gather global group
        x = x.permute(0, 3, 5, 1, 2, 4)  # (n, loc_h, loc_w, c, glb_h, glb_w)
        x = x.reshape(-1, c, glb_h, glb_w)
        # apply attention within each global group
        x = self.global_relation(x)  # (n * loc_h * loc_w, c, glb_h, glb_w)

        # local relation
        x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
        # do permutation to gather local group
        x = x.permute(0, 4, 5, 3, 1, 2)  # (n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.reshape(-1, c, loc_h, loc_w)
        # apply attention within each local group
        x = self.local_relation(x)  # (n * glb_h * glb_w, c, loc_h, loc_w)

        # permute each pixel back to its original position
        x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (n, c, glb_h, loc_h, glb_w, loc_w)
        x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
        if pad_h > 0 or pad_w > 0:  # remove padding
            x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]

        x = self.out_conv(torch.cat([x, residual], dim=1))
        out = self.cls_seg(x)

        return out
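The `view`/`permute`/`reshape` calls in `ISAHead.forward` are easier to follow along a single axis. The sketch below reproduces, in plain Python, the padding split and the way positions are assigned to global and local groups. Both helpers are hypothetical and written only for illustration; the 97-wide feature map is an assumed example (what a 769-pixel crop would give at output stride 8 with aligned corners).

```python
import math


def pad_amounts(size, group):
    """Split the padding that makes `size` divisible by `group`.

    Returns (before, after), using the same `pad // 2` and
    `pad - pad // 2` split as ISAHead.forward.
    """
    num_groups = math.ceil(size / group)
    pad = num_groups * group - size
    return pad // 2, pad - pad // 2


def interlaced_groups(length, loc):
    """Group 1-D positions the way ISA does along one axis.

    `length` must be divisible by `loc`. Position p is viewed as the
    pair (glb_idx, loc_idx) = (p // loc, p % loc).
    """
    glb = length // loc
    # Global groups fix loc_idx and vary glb_idx: members sit `loc` apart.
    global_groups = [[g * loc + l for g in range(glb)] for l in range(loc)]
    # Local groups fix glb_idx and vary loc_idx: members are adjacent.
    local_groups = [[g * loc + l for l in range(loc)] for g in range(glb)]
    return global_groups, local_groups


print(pad_amounts(97, 8))
global_groups, local_groups = interlaced_groups(8, 4)
print(global_groups)
print(local_groups)
```

Attention inside the global groups mixes distant positions, and attention inside the local groups mixes neighbours, so after both steps information can propagate between any pair of positions.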
1 change: 1 addition & 0 deletions model-index.yml
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ Import:
- configs/fp16/fp16.yml
- configs/gcnet/gcnet.yml
- configs/hrnet/hrnet.yml
- configs/isanet/isanet.yml
- configs/mobilenet_v2/mobilenet_v2.yml
- configs/mobilenet_v3/mobilenet_v3.yml
- configs/nonlocal_net/nonlocal_net.yml
5 changes: 5 additions & 0 deletions tests/test_models/test_forward.py
Original file line number Diff line number Diff line change
@@ -175,6 +175,11 @@ def test_emanet_forward():
'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')


def test_isanet_forward():
    _test_encoder_decoder_forward(
        'isanet/isanet_r50-d8_512x1024_40k_cityscapes.py')


def get_world_size(process_group):

    return 1