
Add more test #3

Merged (6 commits, Jul 8, 2020)
26 changes: 15 additions & 11 deletions .github/workflows/build.yml
@@ -39,15 +39,18 @@ jobs:
strategy:
matrix:
python-version: [3.6, 3.7]
torch: [1.3.0, 1.5.0]
torch: [1.3.0+cpu, 1.5.0+cpu]
include:
- torch: 1.3.0
torchvision: 0.4.2
- torch: 1.5.0
torchvision: 0.6.0
- python-version: 3.8
torch: 1.5.0
torchvision: 0.6.0
- torch: 1.3.0+cpu
torchvision: 0.4.2+cpu
- torch: 1.5.0+cpu
torchvision: 0.6.0+cpu
- torch: 1.5.0+cpu
torchvision: 0.6.0+cpu
python-version: 3.8
- torch: 1.5.0+cu101
torchvision: 0.6.0+cu101
python-version: 3.7

steps:
- uses: actions/checkout@v2
@@ -56,6 +59,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install CUDA
if: ${{matrix.torch == '1.5.0+cu101'}}
run: |
export INSTALLER=cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb
wget http://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER}
@@ -70,13 +74,13 @@ jobs:
export PATH=${CUDA_HOME}/bin:${PATH}
sudo apt-get install -y ninja-build
- name: Install Pillow
if: ${{matrix.torchvision == '0.4.2+cpu'}}
run: pip install Pillow==6.2.2
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install mmseg dependencies
run: |
pip install mmcv==1.0rc0+torch${{matrix.torch}}+cu101 -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
pip install mmcv==1.0rc0+torch${{matrix.torch}} -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
@@ -87,7 +91,7 @@ jobs:
coverage report -m --omit="mmseg/utils/*","mmseg/apis/*"
# Only upload coverage report for python3.7 && pytorch1.5
- name: Upload coverage to Codecov
if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}}
if: ${{matrix.torch == '1.5.0+cu101' && matrix.python-version == '3.7'}}
uses: codecov/codecov-action@v1.0.10
with:
file: ./coverage.xml
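After this change, three steps are gated on matrix values: the CUDA install runs only for the new `1.5.0+cu101` entry, the Pillow pin only for torchvision `0.4.2+cpu`, and the coverage upload only for torch `1.5.0+cu101` on Python 3.7. A minimal Python sketch of that gating logic, mirroring the `if:` expressions above (the matrix expansion itself is handled by GitHub Actions and is not reproduced here):

```python
# Sketch of the conditional steps introduced in this diff. The three
# booleans mirror the `if:` expressions in build.yml; values are compared
# as plain strings for illustration.
def steps_for(torch, torchvision, python_version):
    return {
        'install_cuda': torch == '1.5.0+cu101',
        'pin_pillow_6': torchvision == '0.4.2+cpu',
        'upload_coverage': torch == '1.5.0+cu101' and python_version == '3.7',
    }

print(steps_for('1.5.0+cu101', '0.6.0+cu101', '3.7'))  # CUDA + coverage
print(steps_for('1.3.0+cpu', '0.4.2+cpu', '3.6'))      # Pillow pin only
```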
2 changes: 1 addition & 1 deletion README.md
@@ -53,7 +53,7 @@ We wish that the toolbox and benchmark could serve the growing research
community by providing a flexible as well as standardized toolkit to reimplement existing methods
and develop their own new semantic segmentation methods.

Many thanks to Ruobing Han ([@drcut](https://github.com/drcut)), Xiaoming Ma([@aishangmaxiaoming](https://github.com/aishangmaxiaoming)), Shiguang Wang ([@sunnyxiaohu](https://github.com/aishangmaxiaoming)) for deployment support.
Many thanks to Ruobing Han ([@drcut](https://github.com/drcut)), Xiaoming Ma([@aishangmaxiaoming](https://github.com/aishangmaxiaoming)), Shiguang Wang ([@sunnyxiaohu](https://github.com/sunnyxiaohu)) for deployment support.

## Citation

8 changes: 4 additions & 4 deletions docs/tutorials/new_dataset.md
@@ -33,10 +33,10 @@ xxx
zzz
```
Only
`data/my_dataset/img_dir/train/xxx{img_suffix}.png`,
`data/my_dataset/img_dir/train/zzz{img_suffix}.png`,
`data/my_dataset/ann_dir/train/xxx{seg_map_suffix}.png`,
`data/my_dataset/ann_dir/train/zzz{seg_map_suffix}.png` will be loaded.
`data/my_dataset/img_dir/train/xxx{img_suffix}`,
`data/my_dataset/img_dir/train/zzz{img_suffix}`,
`data/my_dataset/ann_dir/train/xxx{seg_map_suffix}`,
`data/my_dataset/ann_dir/train/zzz{seg_map_suffix}` will be loaded.
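For context, a plain-Python sketch of what that statement means (the suffix values and split contents below are illustrative, not part of the diff):

```python
# Sketch: with a split file listing the stems "xxx" and "zzz", only the
# matching image/annotation pairs are loaded; any other files under the
# directories are ignored. Suffix values here are illustrative.
img_suffix = '.jpg'
seg_map_suffix = '.png'
stems = ['xxx', 'zzz']  # one stem per line of the split txt file

img_paths = [f'data/my_dataset/img_dir/train/{s}{img_suffix}' for s in stems]
ann_paths = [f'data/my_dataset/ann_dir/train/{s}{seg_map_suffix}' for s in stems]
print(img_paths)
print(ann_paths)
```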

## Customize datasets by mixing dataset

3 changes: 2 additions & 1 deletion docs/tutorials/training_tricks.md
@@ -4,7 +4,7 @@ MMSegmentation support following training tricks out of box.

## Different Learning Rate(LR) for Backbone and Heads

In semantic segmentation, some methods make the LR of heads larger than backbone to achieve better performance.
In semantic segmentation, some methods make the LR of heads larger than backbone to achieve better performance or faster convergence.

In MMSegmentation, you may add following lines to config to make the LR of heads 10 times of backbone.
```python
@@ -13,6 +13,7 @@ optimizer_config=dict(
custom_keys={
'head': dict(lr_mult=10.)}))
```
With this modification, the LR of any parameter group with `'head'` in name will be multiplied by 10.
You may refer to [MMCV doc](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.DefaultOptimizerConstructor) for further details.
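As a rough illustration of the `lr_mult` behaviour (this is not MMCV's optimizer constructor, and the parameter names are only examples):

```python
# Any parameter whose name contains the substring 'head' gets 10x the base
# learning rate; everything else keeps the base LR.
base_lr = 0.01
param_names = [
    'backbone.layer1.0.conv1.weight',  # no match -> base LR
    'decode_head.conv_seg.weight',     # contains 'head' -> 10x base LR
    'auxiliary_head.conv_seg.weight',  # contains 'head' -> 10x base LR
]
for name in param_names:
    lr = base_lr * (10. if 'head' in name else 1.)
    print(f'{name}: lr={lr}')
```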

## Online Hard Example Mining (OHEM)
6 changes: 3 additions & 3 deletions mmseg/datasets/custom.py
@@ -46,15 +46,15 @@ class CustomDataset(Dataset):
Args:
pipeline (list[dict]): Processing pipeline
img_dir (str): Path to image directory
img_suffix (str): Suffix of images. Default: '.png'
img_suffix (str): Suffix of images. Default: '.jpg'
ann_dir (str, optional): Path to annotation directory. Default: None
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
split (str, optional): Split txt file. If split is specified, only
file with suffix in the splits will be loaded. Otherwise, all
images in img_dir/ann_dir will be loaded. Default: None
data_root (str, optional): Data root for img_dir/ann_dir. Default:
None.
test_mode (str): If test_mode=True, gt wouldn't be loaded.
test_mode (bool): If test_mode=True, gt wouldn't be loaded.
ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False
@@ -67,7 +67,7 @@ class CustomDataset(Dataset):
def __init__(self,
pipeline,
img_dir,
img_suffix='.png',
img_suffix='.jpg',
ann_dir=None,
seg_map_suffix='.png',
split=None,
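The new tests in this PR (tests/test_data/test_dataset.py, further down) exercise these arguments directly. A condensed sketch of that usage, assuming it is run from the repository root and that an empty pipeline is acceptable when only counting samples:

```python
from mmseg.datasets import CustomDataset

# Condensed from the tests added in this PR: the pseudo dataset uses
# 'img.jpg' / 'gt.png' suffixes, and the split file limits which samples
# are loaded. pipeline=[] is an assumption to keep the sketch short.
dataset = CustomDataset(
    pipeline=[],
    data_root='tests/data/pseudo_dataset',
    img_dir='imgs/',
    ann_dir='gts/',
    img_suffix='img.jpg',
    seg_map_suffix='gt.png',
    split='splits/train.txt')
print(len(dataset))  # 4, matching the assert in the new test
```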
22 changes: 15 additions & 7 deletions mmseg/models/decode_heads/enc_head.py
@@ -3,7 +3,7 @@
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.ops import Encoding
from mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead

@@ -30,12 +30,16 @@ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
act_cfg=act_cfg)
# TODO: resolve this hack
# change to 1d
encoding_norm_cfg = norm_cfg.copy()
if encoding_norm_cfg['type'] in ['BN', 'IN']:
encoding_norm_cfg['type'] += '1d'
if norm_cfg is not None:
encoding_norm_cfg = norm_cfg.copy()
if encoding_norm_cfg['type'] in ['BN', 'IN']:
encoding_norm_cfg['type'] += '1d'
else:
encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
'2d', '1d')
else:
encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
'2d', '1d')
# fallback to BN1d
encoding_norm_cfg = dict(type='BN1d')
self.encoding = nn.Sequential(
Encoding(channels=in_channels, num_codes=num_codes),
build_norm_layer(encoding_norm_cfg, num_codes)[1],
@@ -128,7 +132,11 @@ def forward(self, inputs):
feat = self.bottleneck(inputs[-1])
if self.add_lateral:
laterals = [
lateral_conv(inputs[i])
resize(
lateral_conv(inputs[i]),
size=feat.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
for i, lateral_conv in enumerate(self.lateral_convs)
]
feat = self.fusion(torch.cat([feat, *laterals], 1))
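The lateral change above matters because `torch.cat` along the channel dimension requires matching spatial sizes, and `mmseg.ops.resize` is essentially a wrapper around `F.interpolate`. A standalone sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: the bottleneck output (1/32 scale) is smaller than a
# lateral feature (1/16 scale), so concatenating them directly would fail.
feat = torch.randn(1, 512, 16, 16)
lateral = torch.randn(1, 512, 32, 32)

# torch.cat([feat, lateral], dim=1) would raise a size mismatch here, so the
# lateral is first resized to feat's spatial size, as the diff now does.
lateral = F.interpolate(
    lateral, size=feat.shape[2:], mode='bilinear', align_corners=False)
fused = torch.cat([feat, lateral], dim=1)
print(fused.shape)  # torch.Size([1, 1024, 16, 16])
```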
2 changes: 2 additions & 0 deletions mmseg/models/segmentors/encoder_decoder.py
@@ -37,6 +37,8 @@ def __init__(self,

self.init_weights(pretrained=pretrained)

assert self.with_decode_head

def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = builder.build_head(decode_head)
1 change: 0 additions & 1 deletion tests/data/pseudo_dataset/splits/train.txt
@@ -2,4 +2,3 @@
00001
00002
00003
00004
1 change: 1 addition & 0 deletions tests/data/pseudo_dataset/splits/val.txt
@@ -0,0 +1 @@
00004
67 changes: 65 additions & 2 deletions tests/test_data/test_dataset.py
@@ -1,6 +1,9 @@
import os.path as osp
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
CustomDataset, PascalVOCDataset, RepeatDataset)
@@ -13,13 +16,19 @@ def test_classes():
assert list(
ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')

with pytest.raises(ValueError):
get_classes('unsupported')


def test_palette():
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
'pascal_voc')
assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k')

with pytest.raises(ValueError):
get_palette('unsupported')


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
@@ -82,7 +91,7 @@ def test_custom_dataset():
])
]

# train dataset
# with img_dir and ann_dir
train_dataset = CustomDataset(
train_pipeline,
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
@@ -92,6 +101,17 @@ def test_custom_dataset():
seg_map_suffix='gt.png')
assert len(train_dataset) == 5

# with img_dir, ann_dir, split
train_dataset = CustomDataset(
train_pipeline,
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
img_dir='imgs/',
ann_dir='gts/',
img_suffix='img.jpg',
seg_map_suffix='gt.png',
split='splits/train.txt')
assert len(train_dataset) == 4

# no data_root
train_dataset = CustomDataset(
train_pipeline,
@@ -101,10 +121,53 @@ def test_custom_dataset():
seg_map_suffix='gt.png')
assert len(train_dataset) == 5

# test dataset
# with data_root but img_dir/ann_dir are abs path
train_dataset = CustomDataset(
train_pipeline,
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
img_dir=osp.abspath(
osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
ann_dir=osp.abspath(
osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')),
img_suffix='img.jpg',
seg_map_suffix='gt.png')
assert len(train_dataset) == 5

# test_mode=True
test_dataset = CustomDataset(
test_pipeline,
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
img_suffix='img.jpg',
test_mode=True)
assert len(test_dataset) == 5

# training data get
train_data = train_dataset[0]
assert isinstance(train_data, dict)

# test data get
test_data = test_dataset[0]
assert isinstance(test_data, dict)

# get gt seg map
gt_seg_maps = train_dataset.get_gt_seg_maps()
assert len(gt_seg_maps) == 5

# evaluation
pseudo_results = []
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
eval_results = train_dataset.evaluate(pseudo_results)
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results

# evaluation with CLASSES
train_dataset.CLASSES = tuple(['a'] * 7)
eval_results = train_dataset.evaluate(pseudo_results)
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
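The asserts above only check the returned keys. For reference, a generic per-class IoU computation in the same spirit (a plain sketch, not mmseg's own mean_iou implementation):

```python
import numpy as np


def per_class_iou(pred, gt, num_classes, ignore_index=255):
    """Per-class intersection-over-union for one label map (generic sketch)."""
    mask = gt != ignore_index
    pred, gt = pred[mask], gt[mask]
    ious = []
    for cls in range(num_classes):
        inter = np.sum((pred == cls) & (gt == cls))
        union = np.sum((pred == cls) | (gt == cls))
        ious.append(inter / union if union else np.nan)
    return np.array(ious)


pred = np.random.randint(low=0, high=7, size=(4, 4))
gt = np.random.randint(low=0, high=7, size=(4, 4))
print(np.nanmean(per_class_iou(pred, gt, num_classes=7)))  # rough mIoU
```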