From 5cb65218c85708bc6a19387b8fe6bd7fb3686efc Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 8 Jul 2020 09:39:11 +0000 Subject: [PATCH 1/6] add more test --- .github/workflows/build.yml | 23 +-- README.md | 2 +- docs/tutorials/new_dataset.md | 8 +- docs/tutorials/training_tricks.md | 3 +- mmseg/datasets/custom.py | 9 +- mmseg/models/decode_heads/enc_head.py | 22 ++- mmseg/models/segmentors/encoder_decoder.py | 2 + tests/data/pseudo_dataset/splits/all.txt | 5 - tests/test_data/test_dataset.py | 67 +++++++- tests/test_data/test_dataset_builder.py | 100 +++++++++++- tests/test_data/test_transform.py | 62 ++++++- tests/test_models/test_forward.py | 8 +- tests/test_models/test_heads.py | 180 ++++++++++++++++++++- 13 files changed, 445 insertions(+), 46 deletions(-) delete mode 100644 tests/data/pseudo_dataset/splits/all.txt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3183d6a0b5..6bde4487ab 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,15 +39,17 @@ jobs: strategy: matrix: python-version: [3.6, 3.7] - torch: [1.3.0, 1.5.0] + torch: [1.3.0+cpu, 1.5.0+cpu] include: - - torch: 1.3.0 - torchvision: 0.4.2 - - torch: 1.5.0 - torchvision: 0.6.0 + - torch: 1.3.0+cpu + torchvision: 0.4.2+cpu + - torch: 1.5.0+cpu + torchvision: 0.6.0+cpu - python-version: 3.8 - torch: 1.5.0 - torchvision: 0.6.0 + torch: 1.5.0+cpu + torchvision: 0.6.0+cpu + - torch: 1.5.0+cu101 + torchvision: 0.6.0+cu101 steps: - uses: actions/checkout@v2 @@ -69,14 +71,15 @@ jobs: export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${CUDA_HOME}/include:${LD_LIBRARY_PATH} export PATH=${CUDA_HOME}/bin:${PATH} sudo apt-get install -y ninja-build + - if: ${{matrix.torch == '1.5.0+cu101'}} - name: Install Pillow run: pip install Pillow==6.2.2 - if: ${{matrix.torchvision == '0.4.2'}} + if: ${{matrix.torchvision == '0.4.2+cpu'}} - name: Install PyTorch run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html - name: Install mmseg dependencies run: | - pip install mmcv==1.0rc0+torch${{matrix.torch}}+cu101 -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html + pip install mmcv==1.0rc0+torch${{matrix.torch}} -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html pip install -r requirements.txt - name: Build and install run: rm -rf .eggs && pip install -e . @@ -87,7 +90,7 @@ jobs: coverage report -m --omit="mmseg/utils/*","mmseg/apis/*" # Only upload coverage report for python3.7 && pytorch1.5 - name: Upload coverage to Codecov - if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}} + if: ${{matrix.torch == '1.5.0+cu101' && matrix.python-version == '3.7'}} uses: codecov/codecov-action@v1.0.10 with: file: ./coverage.xml diff --git a/README.md b/README.md index 1dc7dff996..1855741a3a 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ We wish that the toolbox and benchmark could serve the growing research community by providing a flexible as well as standardized toolkit to reimplement existing methods and develop their own new semantic segmentation methods. -Many thanks to Ruobing Han ([@drcut](https://github.com/drcut)), Xiaoming Ma([@aishangmaxiaoming](https://github.com/aishangmaxiaoming)), Shiguang Wang ([@sunnyxiaohu](https://github.com/aishangmaxiaoming)) for deployment support. +Many thanks to Ruobing Han ([@drcut](https://github.com/drcut)), Xiaoming Ma([@aishangmaxiaoming](https://github.com/aishangmaxiaoming)), Shiguang Wang ([@sunnyxiaohu](https://github.com/sunnyxiaohu)) for deployment support. ## Citation diff --git a/docs/tutorials/new_dataset.md b/docs/tutorials/new_dataset.md index 237629a436..0ad1019e0e 100644 --- a/docs/tutorials/new_dataset.md +++ b/docs/tutorials/new_dataset.md @@ -33,10 +33,10 @@ xxx zzz ``` Only -`data/my_dataset/img_dir/train/xxx{img_suffix}.png`, -`data/my_dataset/img_dir/train/zzz{img_suffix}.png`, -`data/my_dataset/ann_dir/train/xxx{seg_map_suffix}.png`, -`data/my_dataset/ann_dir/train/zzz{seg_map_suffix}.png` will be loaded. +`data/my_dataset/img_dir/train/xxx{img_suffix}`, +`data/my_dataset/img_dir/train/zzz{img_suffix}`, +`data/my_dataset/ann_dir/train/xxx{seg_map_suffix}`, +`data/my_dataset/ann_dir/train/zzz{seg_map_suffix}` will be loaded. ## Customize datasets by mixing dataset diff --git a/docs/tutorials/training_tricks.md b/docs/tutorials/training_tricks.md index 22270b93b0..5ff4b18a70 100644 --- a/docs/tutorials/training_tricks.md +++ b/docs/tutorials/training_tricks.md @@ -4,7 +4,7 @@ MMSegmentation support following training tricks out of box. ## Different Learning Rate(LR) for Backbone and Heads -In semantic segmentation, some methods make the LR of heads larger than backbone to achieve better performance. +In semantic segmentation, some methods make the LR of heads larger than backbone to achieve better performance or faster convergence. In MMSegmentation, you may add following lines to config to make the LR of heads 10 times of backbone. ```python @@ -13,6 +13,7 @@ optimizer_config=dict( custom_keys={ 'head': dict(lr_mult=10.)})) ``` +With this modification, the LR of any parameter group with `'head'` in name will be multiplied by 10. You may refer to [MMCV doc](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.DefaultOptimizerConstructor) for further details. ## Online Hard Example Mining (OHEM) diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 1797e09451..1d0d7a49e3 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -46,7 +46,7 @@ class CustomDataset(Dataset): Args: pipeline (list[dict]): Processing pipeline img_dir (str): Path to image directory - img_suffix (str): Suffix of images. Default: '.png' + img_suffix (str): Suffix of images. Default: '.jpg' ann_dir (str, optional): Path to annotation directory. Default: None seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' split (str, optional): Split txt file. If split is specified, only @@ -54,7 +54,7 @@ class CustomDataset(Dataset): images in img_dir/ann_dir will be loaded. Default: None data_root (str, optional): Data root for img_dir/ann_dir. Default: None. - test_mode (str): If test_mode=True, gt wouldn't be loaded. + test_mode (bool): If test_mode=True, gt wouldn't be loaded. ignore_index (int): The label index to be ignored. Default: 255 reduce_zero_label (bool): Whether to mark label zero as ignored. Default: False @@ -67,7 +67,7 @@ class CustomDataset(Dataset): def __init__(self, pipeline, img_dir, - img_suffix='.png', + img_suffix='.jpg', ann_dir=None, seg_map_suffix='.png', split=None, @@ -95,6 +95,9 @@ def __init__(self, if not (self.split is None or osp.isabs(self.split)): self.split = osp.join(self.data_root, self.split) + print('loadding ann') + print(self.img_dir) + print(self.ann_dir) # load annotations self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, self.ann_dir, diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py index 5d48c8621e..0c11994cf6 100644 --- a/mmseg/models/decode_heads/enc_head.py +++ b/mmseg/models/decode_heads/enc_head.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule, build_norm_layer -from mmseg.ops import Encoding +from mmseg.ops import Encoding, resize from ..builder import HEADS, build_loss from .decode_head import BaseDecodeHead @@ -30,12 +30,16 @@ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): act_cfg=act_cfg) # TODO: resolve this hack # change to 1d - encoding_norm_cfg = norm_cfg.copy() - if encoding_norm_cfg['type'] in ['BN', 'IN']: - encoding_norm_cfg['type'] += '1d' + if norm_cfg is not None: + encoding_norm_cfg = norm_cfg.copy() + if encoding_norm_cfg['type'] in ['BN', 'IN']: + encoding_norm_cfg['type'] += '1d' + else: + encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( + '2d', '1d') else: - encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( - '2d', '1d') + # fallback to BN1d + encoding_norm_cfg = dict(type='BN1d') self.encoding = nn.Sequential( Encoding(channels=in_channels, num_codes=num_codes), build_norm_layer(encoding_norm_cfg, num_codes)[1], @@ -128,7 +132,11 @@ def forward(self, inputs): feat = self.bottleneck(inputs[-1]) if self.add_lateral: laterals = [ - lateral_conv(inputs[i]) + resize( + lateral_conv(inputs[i]), + size=feat.shape[2:], + mode='bilinear', + align_corners=self.align_corners) for i, lateral_conv in enumerate(self.lateral_convs) ] feat = self.fusion(torch.cat([feat, *laterals], 1)) diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 679f28ccb7..fdec1d987b 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -37,6 +37,8 @@ def __init__(self, self.init_weights(pretrained=pretrained) + assert self.with_decode_head + def _init_decode_head(self, decode_head): """Initialize ``decode_head``""" self.decode_head = builder.build_head(decode_head) diff --git a/tests/data/pseudo_dataset/splits/all.txt b/tests/data/pseudo_dataset/splits/all.txt deleted file mode 100644 index c544a9a3cb..0000000000 --- a/tests/data/pseudo_dataset/splits/all.txt +++ /dev/null @@ -1,5 +0,0 @@ -00000 -00001 -00002 -00003 -00004 diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index beb0b31afc..ee6d2c47a8 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -1,6 +1,9 @@ import os.path as osp from unittest.mock import MagicMock, patch +import numpy as np +import pytest + from mmseg.core.evaluation import get_classes, get_palette from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset, CustomDataset, PascalVOCDataset, RepeatDataset) @@ -13,6 +16,9 @@ def test_classes(): assert list( ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k') + with pytest.raises(ValueError): + get_classes('unsupported') + def test_palette(): assert CityscapesDataset.PALETTE == get_palette('cityscapes') @@ -20,6 +26,9 @@ def test_palette(): 'pascal_voc') assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k') + with pytest.raises(ValueError): + get_palette('unsupported') + @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) @patch('mmseg.datasets.CustomDataset.__getitem__', @@ -82,7 +91,7 @@ def test_custom_dataset(): ]) ] - # train dataset + # with img_dir and ann_dir train_dataset = CustomDataset( train_pipeline, data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'), @@ -92,6 +101,17 @@ def test_custom_dataset(): seg_map_suffix='gt.png') assert len(train_dataset) == 5 + # with img_dir, ann_dir, split + train_dataset = CustomDataset( + train_pipeline, + data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'), + img_dir='imgs/', + ann_dir='gts/', + img_suffix='img.jpg', + seg_map_suffix='gt.png', + split='splits/train.txt') + assert len(train_dataset) == 4 + # no data_root train_dataset = CustomDataset( train_pipeline, @@ -101,10 +121,53 @@ def test_custom_dataset(): seg_map_suffix='gt.png') assert len(train_dataset) == 5 - # test dataset + # with data_root but img_dir/ann_dir are abs path + train_dataset = CustomDataset( + train_pipeline, + data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'), + img_dir=osp.abspath( + osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')), + ann_dir=osp.abspath( + osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')), + img_suffix='img.jpg', + seg_map_suffix='gt.png') + assert len(train_dataset) == 5 + + # test_mode=True test_dataset = CustomDataset( test_pipeline, img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'), img_suffix='img.jpg', test_mode=True) assert len(test_dataset) == 5 + + # training data get + train_data = train_dataset[0] + assert isinstance(train_data, dict) + + # test data get + test_data = test_dataset[0] + assert isinstance(test_data, dict) + + # get gt seg map + gt_seg_maps = train_dataset.get_gt_seg_maps() + assert len(gt_seg_maps) == 5 + + # evaluation + pseudo_results = [] + for gt_seg_map in gt_seg_maps: + h, w = gt_seg_map.shape + pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) + eval_results = train_dataset.evaluate(pseudo_results) + assert isinstance(eval_results, dict) + assert 'mIoU' in eval_results + assert 'mAcc' in eval_results + assert 'aAcc' in eval_results + + # evaluation with CLASSES + train_dataset.CLASSES = tuple(['a'] * 7) + eval_results = train_dataset.evaluate(pseudo_results) + assert isinstance(eval_results, dict) + assert 'mIoU' in eval_results + assert 'mAcc' in eval_results + assert 'aAcc' in eval_results diff --git a/tests/test_data/test_dataset_builder.py b/tests/test_data/test_dataset_builder.py index 4e002c37b2..c6827e4d17 100644 --- a/tests/test_data/test_dataset_builder.py +++ b/tests/test_data/test_dataset_builder.py @@ -1,9 +1,12 @@ import math +import os.path as osp +import pytest from torch.utils.data import (DistributedSampler, RandomSampler, SequentialSampler) -from mmseg.datasets import DATASETS, build_dataloader, build_dataset +from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader, + build_dataset) @DATASETS.register_module() @@ -28,6 +31,101 @@ def test_build_dataset(): assert isinstance(dataset, ToyDataset) assert dataset.cnt == 1 + data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset') + img_dir = 'imgs/' + ann_dir = 'gts/' + + # We use same dir twice for simplicity + # with ann_dir + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=[img_dir, img_dir], + ann_dir=[ann_dir, ann_dir]) + dataset = build_dataset(cfg) + assert isinstance(dataset, ConcatDataset) + assert len(dataset) == 10 + + # with ann_dir, split + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=img_dir, + ann_dir=ann_dir, + split=['splits/train.txt', 'splits/val.txt']) + dataset = build_dataset(cfg) + assert isinstance(dataset, ConcatDataset) + assert len(dataset) == 5 + + # with ann_dir, split + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=img_dir, + ann_dir=[ann_dir, ann_dir], + split=['splits/train.txt', 'splits/val.txt']) + dataset = build_dataset(cfg) + assert isinstance(dataset, ConcatDataset) + assert len(dataset) == 5 + + # test mode + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=[img_dir, img_dir], + test_mode=True) + dataset = build_dataset(cfg) + assert isinstance(dataset, ConcatDataset) + assert len(dataset) == 10 + + # test mode with splits + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=[img_dir, img_dir], + split=['splits/val.txt', 'splits/val.txt'], + test_mode=True) + dataset = build_dataset(cfg) + assert isinstance(dataset, ConcatDataset) + assert len(dataset) == 2 + + # len(ann_dir) should be zero or len(img_dir) when len(img_dir) > 1 + with pytest.raises(AssertionError): + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=[img_dir, img_dir], + ann_dir=[ann_dir, ann_dir, ann_dir]) + build_dataset(cfg) + + # len(splits) should be zero or len(img_dir) when len(img_dir) > 1 + with pytest.raises(AssertionError): + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=[img_dir, img_dir], + split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt']) + build_dataset(cfg) + + # len(splits) == len(ann_dir) when only len(img_dir) == 1 and len( + # ann_dir) > 1 + with pytest.raises(AssertionError): + cfg = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + img_dir=img_dir, + ann_dir=[ann_dir, ann_dir], + split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt']) + build_dataset(cfg) + def test_build_dataloader(): dataset = ToyDataset() diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 50645f94bd..9c11b7032f 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -47,17 +47,52 @@ def test_resize(): results['pad_shape'] = img.shape results['scale_factor'] = 1.0 - results = resize_module(results) + resized_results = resize_module(results.copy()) + assert resized_results['img_shape'] == (750, 1333, 3) - results.pop('scale') + # test keep_ratio=False transform = dict( type='Resize', img_scale=(1280, 800), multiscale_mode='value', keep_ratio=False) resize_module = build_from_cfg(transform, PIPELINES) - results = resize_module(results) - assert results['img_shape'] == (800, 1280, 3) + resized_results = resize_module(results.copy()) + assert resized_results['img_shape'] == (800, 1280, 3) + + # test multiscale_mode='range' + transform = dict( + type='Resize', + img_scale=[(1333, 400), (1333, 1200)], + multiscale_mode='range', + keep_ratio=True) + resize_module = build_from_cfg(transform, PIPELINES) + resized_results = resize_module(results.copy()) + assert max(resized_results['img_shape'][:2]) <= 1333 + assert min(resized_results['img_shape'][:2]) >= 400 + assert min(resized_results['img_shape'][:2]) <= 1200 + + # test multiscale_mode='value' + transform = dict( + type='Resize', + img_scale=[(1333, 800), (1333, 400)], + multiscale_mode='value', + keep_ratio=True) + resize_module = build_from_cfg(transform, PIPELINES) + resized_results = resize_module(results.copy()) + assert resized_results['img_shape'] in [(750, 1333, 3), (400, 711, 3)] + + # test multiscale_mode='range' + transform = dict( + type='Resize', + img_scale=(1333, 800), + ratio_range=(0.9, 1.1), + keep_ratio=True) + resize_module = build_from_cfg(transform, PIPELINES) + resized_results = resize_module(results.copy()) + assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1 + assert min(resized_results['img_shape'][:2]) >= 800 * 0.9 + assert min(resized_results['img_shape'][:2]) <= 800 * 1.1 def test_flip(): @@ -188,3 +223,22 @@ def test_normalize(): std = np.array(img_norm_cfg['std']) converted_img = (original_img[..., ::-1] - mean) / std assert np.allclose(results['img'], converted_img) + + +def test_seg_rescale(): + results = dict() + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + h, w = seg.shape + + transform = dict(type='SegRescale', scale_factor=1. / 2) + rescale_module = build_from_cfg(transform, PIPELINES) + rescale_results = rescale_module(results.copy()) + assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2) + + transform = dict(type='SegRescale', scale_factor=1) + rescale_module = build_from_cfg(transform, PIPELINES) + rescale_results = rescale_module(results.copy()) + assert rescale_results['gt_semantic_seg'].shape == (h, w) diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index f6636e6079..620b82e64d 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -4,6 +4,7 @@ from unittest.mock import patch import numpy as np +import pytest import torch import torch.nn as nn from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm @@ -112,9 +113,10 @@ def test_ann_forward(): def test_ccnet_forward(): - if torch.cuda.is_available(): - _test_encoder_decoder_forward( - 'ccnet/ccnet_r50-d8_512x1024_40k_cityscapes.py') + if not torch.cuda.is_available(): + pytest.skip('CCNet requires CUDA') + _test_encoder_decoder_forward( + 'ccnet/ccnet_r50-d8_512x1024_40k_cityscapes.py') def test_danet_forward(): diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 59bd06b9a3..935239438f 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,6 +6,7 @@ from mmcv.utils.parrots_wrapper import SyncBatchNorm from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead, + DepthwiseSeparableASPPHead, EncHead, FCNHead, GCHead, NLHead, OCRHead, PSAHead, PSPHead, UPerHead) from mmseg.models.decode_heads.decode_head import BaseDecodeHead @@ -260,6 +261,7 @@ def test_psa_head(): norm_cfg=dict(type='SyncBN')) assert _conv_has_norm(head, sync_bn=True) + # test 'bi-direction' psa_type inputs = [torch.randn(1, 32, 39, 39)] head = PSAHead( in_channels=32, channels=16, num_classes=19, mask_size=(39, 39)) @@ -268,6 +270,87 @@ def test_psa_head(): outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 39, 39) + # test 'bi-direction' psa_type, shrink_factor=1 + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + shrink_factor=1) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'bi-direction' psa_type with soft_max + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_softmax=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'collect' psa_type + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='collect') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'collect' psa_type, shrink_factor=1 + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + shrink_factor=1, + psa_type='collect') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'collect' psa_type, shrink_factor=1, compact=True + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='collect', + shrink_factor=1, + compact=True) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + + # test 'distribute' psa_type + inputs = [torch.randn(1, 32, 39, 39)] + head = PSAHead( + in_channels=32, + channels=16, + num_classes=19, + mask_size=(39, 39), + psa_type='distribute') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 39, 39) + def test_gc_head(): head = GCHead(in_channels=32, channels=16, num_classes=19) @@ -295,11 +378,12 @@ def test_cc_head(): head = CCHead(in_channels=32, channels=16, num_classes=19) assert len(head.convs) == 2 assert hasattr(head, 'cca') - if torch.cuda.is_available(): - inputs = [torch.randn(1, 32, 45, 45)] - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 45, 45) + if not torch.cuda.is_available(): + pytest.skip('CCHead requires CUDA') + inputs = [torch.randn(1, 32, 45, 45)] + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) def test_uper_head(): @@ -353,8 +437,11 @@ def test_da_head(): if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) + assert isinstance(outputs, tuple) and len(outputs) == 3 for output in outputs: assert output.shape == (1, head.num_classes, 45, 45) + test_output = head.forward_test(inputs, None, None) + assert test_output.shape == (1, head.num_classes, 45, 45) def test_ocr_head(): @@ -369,3 +456,86 @@ def test_ocr_head(): prev_output = fcn_head(inputs) output = ocr_head(inputs, prev_output) assert output.shape == (1, ocr_head.num_classes, 45, 45) + + +def test_enc_head(): + # with se_loss, w.o. lateral + inputs = [torch.randn(1, 32, 21, 21)] + head = EncHead( + in_channels=[32], channels=16, num_classes=19, in_index=[-1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert isinstance(outputs, tuple) and len(outputs) == 2 + assert outputs[0].shape == (1, head.num_classes, 21, 21) + assert outputs[1].shape == (1, head.num_classes) + + # w.o se_loss, w.o. lateral + inputs = [torch.randn(1, 32, 21, 21)] + head = EncHead( + in_channels=[32], + channels=16, + use_se_loss=False, + num_classes=19, + in_index=[-1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 21, 21) + + # with se_loss, with lateral + inputs = [torch.randn(1, 16, 45, 45), torch.randn(1, 32, 21, 21)] + head = EncHead( + in_channels=[16, 32], + channels=16, + add_lateral=True, + num_classes=19, + in_index=[-2, -1]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert isinstance(outputs, tuple) and len(outputs) == 2 + assert outputs[0].shape == (1, head.num_classes, 21, 21) + assert outputs[1].shape == (1, head.num_classes) + test_output = head.forward_test(inputs, None, None) + assert test_output.shape == (1, head.num_classes, 21, 21) + + +def test_dw_aspp_head(): + + # test w.o. c1 + inputs = [torch.randn(1, 32, 45, 45)] + head = DepthwiseSeparableASPPHead( + c1_in_channels=0, + c1_channels=0, + in_channels=32, + channels=16, + num_classes=19, + dilations=(1, 12, 24)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.c1_bottleneck is None + assert head.aspp_modules[0].conv.dilation == (1, 1) + assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) + assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) + + # test with c1 + inputs = [torch.randn(1, 8, 45, 45), torch.randn(1, 32, 21, 21)] + head = DepthwiseSeparableASPPHead( + c1_in_channels=8, + c1_channels=4, + in_channels=32, + channels=16, + num_classes=19, + dilations=(1, 12, 24)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.c1_bottleneck.in_channels == 8 + assert head.c1_bottleneck.out_channels == 4 + assert head.aspp_modules[0].conv.dilation == (1, 1) + assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) + assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 45, 45) From 1cb4481a76330a13817934b74bb2830c4a1ec95f Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 8 Jul 2020 09:45:30 +0000 Subject: [PATCH 2/6] fixed typo --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6bde4487ab..830a982544 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -58,6 +58,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install CUDA + if: ${{matrix.torch == '1.5.0+cu101'}} run: | export INSTALLER=cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb wget http://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER} @@ -71,10 +72,9 @@ jobs: export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${CUDA_HOME}/include:${LD_LIBRARY_PATH} export PATH=${CUDA_HOME}/bin:${PATH} sudo apt-get install -y ninja-build - - if: ${{matrix.torch == '1.5.0+cu101'}} - name: Install Pillow - run: pip install Pillow==6.2.2 if: ${{matrix.torchvision == '0.4.2+cpu'}} + run: pip install Pillow==6.2.2 - name: Install PyTorch run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html - name: Install mmseg dependencies From ea061422b89a5c606f03d3351dc66ec7b9de358f Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 8 Jul 2020 09:49:58 +0000 Subject: [PATCH 3/6] add test data --- tests/data/pseudo_dataset/splits/train.txt | 4 ++++ tests/data/pseudo_dataset/splits/val.txt | 1 + 2 files changed, 5 insertions(+) create mode 100644 tests/data/pseudo_dataset/splits/train.txt create mode 100644 tests/data/pseudo_dataset/splits/val.txt diff --git a/tests/data/pseudo_dataset/splits/train.txt b/tests/data/pseudo_dataset/splits/train.txt new file mode 100644 index 0000000000..9e25ab0266 --- /dev/null +++ b/tests/data/pseudo_dataset/splits/train.txt @@ -0,0 +1,4 @@ +00000 +00001 +00002 +00003 diff --git a/tests/data/pseudo_dataset/splits/val.txt b/tests/data/pseudo_dataset/splits/val.txt new file mode 100644 index 0000000000..59dd536625 --- /dev/null +++ b/tests/data/pseudo_dataset/splits/val.txt @@ -0,0 +1 @@ +00004 From 57bb9bcaf3d0c113687637dbb4b7fef3a137098a Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 8 Jul 2020 10:02:29 +0000 Subject: [PATCH 4/6] fixed cuda python version --- .github/workflows/build.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 830a982544..e0aef9538c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,11 +45,12 @@ jobs: torchvision: 0.4.2+cpu - torch: 1.5.0+cpu torchvision: 0.6.0+cpu - - python-version: 3.8 - torch: 1.5.0+cpu + - torch: 1.5.0+cpu torchvision: 0.6.0+cpu + python-version: 3.8 - torch: 1.5.0+cu101 torchvision: 0.6.0+cu101 + python-version: 3.7 steps: - uses: actions/checkout@v2 From bd1b180e1f27c81c6c13ff116c5e844a408c88ab Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 8 Jul 2020 10:26:38 +0000 Subject: [PATCH 5/6] fixed resize test --- tests/test_data/test_transform.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 9c11b7032f..7a1ca0dde3 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -91,8 +91,6 @@ def test_resize(): resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1 - assert min(resized_results['img_shape'][:2]) >= 800 * 0.9 - assert min(resized_results['img_shape'][:2]) <= 800 * 1.1 def test_flip(): From 884582bfae8c850ebc1293888490e54e0118acdb Mon Sep 17 00:00:00 2001 From: Jiarui XU Date: Wed, 8 Jul 2020 12:55:03 +0000 Subject: [PATCH 6/6] remove debug ingo --- mmseg/datasets/custom.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 1d0d7a49e3..92d17c5252 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -95,9 +95,6 @@ def __init__(self, if not (self.split is None or osp.isabs(self.split)): self.split = osp.join(self.data_root, self.split) - print('loadding ann') - print(self.img_dir) - print(self.ann_dir) # load annotations self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, self.ann_dir,