Skip to content

Commit

Permalink
[Fix] Fix bug when loading class name form file in custom dataset (op…
Browse files Browse the repository at this point in the history
…en-mmlab#923)

* [Fix] open-mmlab#916 expection string type classes

* add unittests for string path classes

* fix double quote string in test_dataset.py

* move the import to the top of the file

* fix isort lint error

fix isort lint error when move the import to the top of the file
  • Loading branch information
ShoupingShan authored Oct 7, 2021
1 parent 1ce4904 commit adb1cd3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mmseg/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def get_classes_and_palette(self, classes=None, palette=None):
raise ValueError(f'Unsupported type {type(classes)} of classes.')

if self.CLASSES:
if not set(classes).issubset(self.CLASSES):
if not set(class_names).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.')

# dictionary, its keys are the old label ids and its values
Expand All @@ -330,7 +330,7 @@ def get_classes_and_palette(self, classes=None, palette=None):
if c not in class_names:
self.label_map[i] = -1
else:
self.label_map[i] = classes.index(c)
self.label_map[i] = class_names.index(c)

palette = self.get_palette_for_custom_classes(class_names, palette)

Expand Down
33 changes: 33 additions & 0 deletions tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from typing import Generator
from unittest.mock import MagicMock, patch

Expand All @@ -26,6 +28,37 @@ def test_classes():
get_classes('unsupported')


def test_classes_file_path():
tmp_file = tempfile.NamedTemporaryFile()
classes_path = f'{tmp_file.name}.txt'
train_pipeline = [dict(type='LoadImageFromFile')]
kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)

# classes.txt with full categories
categories = get_classes('cityscapes')
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
assert list(CityscapesDataset(**kwargs).CLASSES) == categories

# classes.txt with sub categories
categories = ['road', 'sidewalk', 'building']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
assert list(CityscapesDataset(**kwargs).CLASSES) == categories

# classes.txt with unknown categories
categories = ['road', 'sidewalk', 'unknown']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))

with pytest.raises(ValueError):
CityscapesDataset(**kwargs)

tmp_file.close()
os.remove(classes_path)
assert not osp.exists(classes_path)


def test_palette():
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
Expand Down

0 comments on commit adb1cd3

Please sign in to comment.