From 4bfc1182f76a2efe4d547c437f42283ec33ae81e Mon Sep 17 00:00:00 2001 From: Benjamin Pinaya Date: Mon, 19 Nov 2018 09:14:17 +0100 Subject: [PATCH 1/7] VOC Dataset, linted, flak8 passing, samples on gist. --- torchvision/datasets/__init__.py | 3 +- torchvision/datasets/voc.py | 264 +++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 torchvision/datasets/voc.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e2d2801216a..662a682712d 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -9,10 +9,11 @@ from .fakedata import FakeData from .semeion import SEMEION from .omniglot import Omniglot +from .voc import VOCSegmentation, VOCDetection __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', - 'Omniglot') + 'Omniglot', 'VOCSegmentation', 'VOCDetection') diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py new file mode 100644 index 00000000000..167664dbe3c --- /dev/null +++ b/torchvision/datasets/voc.py @@ -0,0 +1,264 @@ +import os +import sys +import tarfile +import torch.utils.data as data +if sys.version_info[0] == 2: + import xml.etree.cElementTree as ET +else: + import xml.etree.ElementTree as ET + +from PIL import Image +from .utils import download_url, check_integrity + +VOC_CLASSES = [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', + 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' +] +DATASET_YEAR_DICT = { + '2012': [ + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', + 'VOCtrainval_11-May-2012.tar', '6cd6e144f989b92b3379bac3b3de84fd', + ' VOCdevkit/VOC2012' + ], + '2011': [ + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', + 'VOCtrainval_25-May-2011.tar', '6c3384ef61512963050cb5d687e5bf1e', + 'TrainVal/VOCdevkit/VOC2011' + ], + '2010': [ + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', + 'VOCtrainval_03-May-2010.tar', 'da459979d0c395079b5c75ee67908abb', + 'VOCdevkit/VOC2010' + ], + '2009': [ + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', + 'VOCtrainval_11-May-2009.tar', '59065e4b188729180974ef6572f6a212', + 'VOCdevkit/VOC2009' + ], + '2008': [ + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', + 'VOCtrainval_14-Jul-2008.tar', '2629fa636546599198acfcfbfcf1904a', + 'VOCdevkit/VOC2008' + ], + '2007': [ + 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', + 'VOCtrainval_06-Nov-2007.tar', 'c52e279531787c972589f7e41ab4ae64', + 'VOCdevkit/VOC2007' + ] +} + + +class VOCSegmentation(data.Dataset): + """`Pascal VOC `_ Segmentation Dataset. + + Args: + root (string): Root directory of the VOC Dataset. + year (string, optional): The dataset year, supports years 2007 to 2012. + image_set (string, optional): Select the image_set to use, ``train, trainval or val`` + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. 
E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, + root, + year='2012', + image_set='train', + download=False, + transform=None, + target_transform=None): + self.root = root + self.year = year + self.url = DATASET_YEAR_DICT[year][0] + self.filename = DATASET_YEAR_DICT[year][1] + self.md5 = DATASET_YEAR_DICT[year][2] + self.transform = transform + self.target_transform = target_transform + self.image_set = image_set + _base_dir = DATASET_YEAR_DICT[year][3] + _voc_root = os.path.join(self.root, _base_dir) + _image_dir = os.path.join(_voc_root, 'JPEGImages') + _mask_dir = os.path.join(_voc_root, 'SegmentationClass') + + if download: + download_extract(self.url, self.root, self.filename, self.md5) + + if not os.path.isdir(_voc_root): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation') + + _split_f = os.path.join(_splits_dir, image_set.rstrip('\n') + '.txt') + + if not os.path.exists(_split_f): + raise ValueError( + 'Wrong image_set entered! Please use image_set="train" ' + 'or image_set="trainval" or image_set="val"') + + self.images = [] + self.masks = [] + with open(os.path.join(_split_f), "r") as lines: + for line in lines: + _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") + _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png") + assert os.path.isfile(_image) + assert os.path.isfile(_mask) + self.images.append(_image) + self.masks.append(_mask) + + assert (len(self.images) == len(self.masks)) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is the image segmentation. + """ + _img = Image.open(self.images[index]).convert('RGB') + _target = Image.open(self.masks[index]) + + if self.transform is not None: + _img = self.transform(_img) + + if self.target_transform is not None: + _target = self.target_transform(_target) + + return _img, _target + + def __len__(self): + return len(self.images) + + +class VOCDetection(data.Dataset): + """`Pascal VOC `_ Detection Dataset. + + Args: + root (string): Root directory of the VOC Dataset. + year (string, optional): The dataset year, supports years 2007 to 2012. + image_set (string, optional): Select the image_set to use, ``train, trainval or val`` + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + class_to_ind (dict, optional): dictionary lookup of classnames -> indexes + (default: alphabetic indexing of VOC's 20 classes). + keep_difficult (boolean, optional): keep difficult instances or not. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. 
+ """ + + def __init__(self, + root, + year='2012', + image_set='train', + download=False, + class_to_ind=None, + keep_difficult=False, + transform=None, + target_transform=None): + self.root = root + self.year = year + self.url = DATASET_YEAR_DICT[year][0] + self.filename = DATASET_YEAR_DICT[year][1] + self.md5 = DATASET_YEAR_DICT[year][2] + self.transform = transform + self.target_transform = target_transform + self.image_set = image_set + self.class_to_ind = class_to_ind or dict( + zip(VOC_CLASSES, range(len(VOC_CLASSES)))) + self.keep_difficult = keep_difficult + _base_dir = DATASET_YEAR_DICT[year][3] + _voc_root = os.path.join(self.root, _base_dir) + _image_dir = os.path.join(_voc_root, 'JPEGImages') + _annotation_dir = os.path.join(_voc_root, 'Annotations') + + if download: + download_extract(self.url, self.root, self.filename, self.md5) + + if not os.path.isdir(_voc_root): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + _splits_dir = os.path.join(_voc_root, 'ImageSets/Main') + + _split_f = os.path.join(_splits_dir, image_set.rstrip('\n') + '.txt') + + if not os.path.exists(_split_f): + raise ValueError( + 'Wrong image_set entered! Please use image_set="train" ' + 'or image_set="trainval" or image_set="val" or a valid' + 'image_set from the VOC ImageSets/Main folder.') + + self.images = [] + self.annotations = [] + with open(os.path.join(_split_f), "r") as lines: + for line in lines: + _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") + _annotation = os.path.join(_annotation_dir, + line.rstrip('\n') + ".xml") + assert os.path.isfile(_image) + assert os.path.isfile(_annotation) + self.images.append(_image) + self.annotations.append(_annotation) + + assert (len(self.images) == len(self.annotations)) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is a list of bounding boxes of + relative coordinates like``[[xmin, ymin, xmax, ymax, ind], [...], ...]``. + """ + _img = Image.open(self.images[index]).convert('RGB') + _target = self._get_bboxes(ET.parse(self.annotations[index]).getroot()) + + if self.transform is not None: + _img = self.transform(_img) + + if self.target_transform is not None: + _target = self.target_transform(_target) + + return _img, _target + + def __len__(self): + return len(self.images) + + def _get_bboxes(self, target): + res = [] + for obj in target.iter('object'): + difficult = int(obj.find('difficult').text) == 1 + if not self.keep_difficult and difficult: + continue + name = obj.find('name').text.lower().strip() + bbox = obj.find('bndbox') + width = int(target.find('size').find('width').text) + height = int(target.find('size').find('height').text) + bndbox = [] + for i, cur_bb in enumerate(bbox): + bb_sz = int(cur_bb.text) - 1 + # relative coordinates + bb_sz = bb_sz / width if i % 2 == 0 else bb_sz / height + bndbox.append(bb_sz) + + label_ind = self.class_to_ind[name] + bndbox.append(label_ind) + res.append(bndbox) # [xmin, ymin, xmax, ymax, ind] + return res + + +def download_extract(url, root, filename, md5): + download_url(url, root, filename, md5) + with tarfile.open(os.path.join(root, filename), "r") as tar: + tar.extractall(path=root) From b9d9ade72a812d903f9957e5ee5da9329f0f566b Mon Sep 17 00:00:00 2001 From: Benjamin Pinaya Date: Tue, 20 Nov 2018 09:57:52 +0100 Subject: [PATCH 2/7] Double backtick on values. 
--- torchvision/datasets/voc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 167664dbe3c..3c749c2ed9a 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -55,7 +55,7 @@ class VOCSegmentation(data.Dataset): Args: root (string): Root directory of the VOC Dataset. year (string, optional): The dataset year, supports years 2007 to 2012. - image_set (string, optional): Select the image_set to use, ``train, trainval or val`` + image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. @@ -143,7 +143,7 @@ class VOCDetection(data.Dataset): Args: root (string): Root directory of the VOC Dataset. year (string, optional): The dataset year, supports years 2007 to 2012. - image_set (string, optional): Select the image_set to use, ``train, trainval or val`` + image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. From 25145685de175165cefd5ada35a4d571af1dd500 Mon Sep 17 00:00:00 2001 From: Ellis Brown Date: Wed, 28 Nov 2018 11:10:57 +0100 Subject: [PATCH 3/7] Apply suggestions from code review Add suggestions from @ellisbrown, using dict of dicts instead of array index. Co-Authored-By: bpinaya --- torchvision/datasets/voc.py | 53 +++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 3c749c2ed9a..8cfc582c88d 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -16,6 +16,43 @@ 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ] DATASET_YEAR_DICT = { + '2012': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '6cd6e144f989b92b3379bac3b3de84fd', + 'base_dir': 'VOCdevkit/VOC2012' + }, + '2011': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', + 'filename': 'VOCtrainval_25-May-2011.tar', + 'md5': '6c3384ef61512963050cb5d687e5bf1e', + 'base_dir': 'TrainVal/VOCdevkit/VOC2011' + }, + '2010': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', + 'filename': 'VOCtrainval_03-May-2010.tar, + 'md5': 'da459979d0c395079b5c75ee67908abb', + 'base_dir': 'VOCdevkit/VOC2010' + }, + '2009': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', + 'filename': 'VOCtrainval_11-May-2009.tar'', + 'md5': '59065e4b188729180974ef6572f6a212', + 'base_dir': 'VOCdevkit/VOC2009' + }, + '2008': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '2629fa636546599198acfcfbfcf1904a', + 'base_dir': 'VOCdevkit/VOC2008' + }, + '2007': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', + 'filename': 'VOCtrainval_06-Nov-2007.tar', + 'md5': 'c52e279531787c972589f7e41ab4ae64', + 'base_dir': 'VOCdevkit/VOC2007' + } +} '2012': [ 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 'VOCtrainval_11-May-2012.tar', 
'6cd6e144f989b92b3379bac3b3de84fd', @@ -74,13 +111,13 @@ def __init__(self, target_transform=None): self.root = root self.year = year - self.url = DATASET_YEAR_DICT[year][0] - self.filename = DATASET_YEAR_DICT[year][1] - self.md5 = DATASET_YEAR_DICT[year][2] + self.url = DATASET_YEAR_DICT[year]['url'] + self.filename = DATASET_YEAR_DICT[year]['filename'] + self.md5 = DATASET_YEAR_DICT[year]['md5'] self.transform = transform self.target_transform = target_transform self.image_set = image_set - _base_dir = DATASET_YEAR_DICT[year][3] + _base_dir = DATASET_YEAR_DICT[year]['base_dir'] _voc_root = os.path.join(self.root, _base_dir) _image_dir = os.path.join(_voc_root, 'JPEGImages') _mask_dir = os.path.join(_voc_root, 'SegmentationClass') @@ -167,16 +204,16 @@ def __init__(self, target_transform=None): self.root = root self.year = year - self.url = DATASET_YEAR_DICT[year][0] - self.filename = DATASET_YEAR_DICT[year][1] - self.md5 = DATASET_YEAR_DICT[year][2] + self.url = DATASET_YEAR_DICT[year]['url'] + self.filename = DATASET_YEAR_DICT[year]['filename'] + self.md5 = DATASET_YEAR_DICT[year]['md5'] self.transform = transform self.target_transform = target_transform self.image_set = image_set self.class_to_ind = class_to_ind or dict( zip(VOC_CLASSES, range(len(VOC_CLASSES)))) self.keep_difficult = keep_difficult - _base_dir = DATASET_YEAR_DICT[year][3] + _base_dir = DATASET_YEAR_DICT[year]['base_dir'] _voc_root = os.path.join(self.root, _base_dir) _image_dir = os.path.join(_voc_root, 'JPEGImages') _annotation_dir = os.path.join(_voc_root, 'Annotations') From c67cfb18b8102f5b8039ef62d0f8c7a193f64fca Mon Sep 17 00:00:00 2001 From: Benjamin Pinaya Date: Wed, 28 Nov 2018 11:27:25 +0100 Subject: [PATCH 4/7] Fixed errors with the new comments. --- torchvision/datasets/voc.py | 35 ++--------------------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 8cfc582c88d..d05f4eb537e 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -30,13 +30,13 @@ }, '2010': { 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', - 'filename': 'VOCtrainval_03-May-2010.tar, + 'filename': 'VOCtrainval_03-May-2010.tar', 'md5': 'da459979d0c395079b5c75ee67908abb', 'base_dir': 'VOCdevkit/VOC2010' }, '2009': { 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', - 'filename': 'VOCtrainval_11-May-2009.tar'', + 'filename': 'VOCtrainval_11-May-2009.tar', 'md5': '59065e4b188729180974ef6572f6a212', 'base_dir': 'VOCdevkit/VOC2009' }, @@ -53,37 +53,6 @@ 'base_dir': 'VOCdevkit/VOC2007' } } - '2012': [ - 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', - 'VOCtrainval_11-May-2012.tar', '6cd6e144f989b92b3379bac3b3de84fd', - ' VOCdevkit/VOC2012' - ], - '2011': [ - 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', - 'VOCtrainval_25-May-2011.tar', '6c3384ef61512963050cb5d687e5bf1e', - 'TrainVal/VOCdevkit/VOC2011' - ], - '2010': [ - 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', - 'VOCtrainval_03-May-2010.tar', 'da459979d0c395079b5c75ee67908abb', - 'VOCdevkit/VOC2010' - ], - '2009': [ - 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', - 'VOCtrainval_11-May-2009.tar', '59065e4b188729180974ef6572f6a212', - 'VOCdevkit/VOC2009' - ], - '2008': [ - 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', - 'VOCtrainval_14-Jul-2008.tar', 
'2629fa636546599198acfcfbfcf1904a', - 'VOCdevkit/VOC2008' - ], - '2007': [ - 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', - 'VOCtrainval_06-Nov-2007.tar', 'c52e279531787c972589f7e41ab4ae64', - 'VOCdevkit/VOC2007' - ] -} class VOCSegmentation(data.Dataset): From 6481f087f1714b2a155aef0a8afbd0c3888f2083 Mon Sep 17 00:00:00 2001 From: bpinaya Date: Wed, 5 Dec 2018 16:10:21 +0100 Subject: [PATCH 5/7] Added documentation on RST --- docs/source/datasets.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 230f9ae4627..5e751677f31 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -129,3 +129,15 @@ PhotoTour .. autoclass:: PhotoTour :members: __getitem__ :special-members: + +VOC +~~~~~~ + + +.. autoclass:: VOCSegmentation + :members: __getitem__ + :special-members: + +.. autoclass:: VOCDetection + :members: __getitem__ + :special-members: \ No newline at end of file From 15fc44a421857d1b3975bbed6919bae44d6bb6e7 Mon Sep 17 00:00:00 2001 From: Benjamin Pinaya Date: Thu, 6 Dec 2018 12:19:26 +0100 Subject: [PATCH 6/7] Removed getBB, added parse_voc_xml, variable naming change. --- torchvision/datasets/voc.py | 130 ++++++++++++++++-------------------- 1 file changed, 56 insertions(+), 74 deletions(-) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index d05f4eb537e..95a29fe59da 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -1,6 +1,7 @@ import os import sys import tarfile +import collections import torch.utils.data as data if sys.version_info[0] == 2: import xml.etree.cElementTree as ET @@ -86,38 +87,32 @@ def __init__(self, self.transform = transform self.target_transform = target_transform self.image_set = image_set - _base_dir = DATASET_YEAR_DICT[year]['base_dir'] - _voc_root = os.path.join(self.root, _base_dir) - _image_dir = os.path.join(_voc_root, 'JPEGImages') - _mask_dir = os.path.join(_voc_root, 'SegmentationClass') + base_dir = DATASET_YEAR_DICT[year]['base_dir'] + voc_root = os.path.join(self.root, base_dir) + image_dir = os.path.join(voc_root, 'JPEGImages') + mask_dir = os.path.join(voc_root, 'SegmentationClass') if download: download_extract(self.url, self.root, self.filename, self.md5) - if not os.path.isdir(_voc_root): + if not os.path.isdir(voc_root): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') - _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation') + splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') - _split_f = os.path.join(_splits_dir, image_set.rstrip('\n') + '.txt') + split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') - if not os.path.exists(_split_f): + if not os.path.exists(split_f): raise ValueError( 'Wrong image_set entered! 
Please use image_set="train" ' 'or image_set="trainval" or image_set="val"') - self.images = [] - self.masks = [] - with open(os.path.join(_split_f), "r") as lines: - for line in lines: - _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") - _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png") - assert os.path.isfile(_image) - assert os.path.isfile(_mask) - self.images.append(_image) - self.masks.append(_mask) + with open(os.path.join(split_f), "r") as f: + file_names = [x.strip() for x in f.readlines()] + self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] + self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] assert (len(self.images) == len(self.masks)) def __getitem__(self, index): @@ -128,16 +123,16 @@ def __getitem__(self, index): Returns: tuple: (image, target) where target is the image segmentation. """ - _img = Image.open(self.images[index]).convert('RGB') - _target = Image.open(self.masks[index]) + img = Image.open(self.images[index]).convert('RGB') + target = Image.open(self.masks[index]) if self.transform is not None: - _img = self.transform(_img) + img = self.transform(img) if self.target_transform is not None: - _target = self.target_transform(_target) + target = self.target_transform(target) - return _img, _target + return img, target def __len__(self): return len(self.images) @@ -155,10 +150,9 @@ class VOCDetection(data.Dataset): downloaded again. class_to_ind (dict, optional): dictionary lookup of classnames -> indexes (default: alphabetic indexing of VOC's 20 classes). - keep_difficult (boolean, optional): keep difficult instances or not. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the + target_transform (callable, required): A function/transform that takes in the target and transforms it. """ @@ -168,7 +162,6 @@ def __init__(self, image_set='train', download=False, class_to_ind=None, - keep_difficult=False, transform=None, target_transform=None): self.root = root @@ -181,41 +174,33 @@ def __init__(self, self.image_set = image_set self.class_to_ind = class_to_ind or dict( zip(VOC_CLASSES, range(len(VOC_CLASSES)))) - self.keep_difficult = keep_difficult - _base_dir = DATASET_YEAR_DICT[year]['base_dir'] - _voc_root = os.path.join(self.root, _base_dir) - _image_dir = os.path.join(_voc_root, 'JPEGImages') - _annotation_dir = os.path.join(_voc_root, 'Annotations') + base_dir = DATASET_YEAR_DICT[year]['base_dir'] + voc_root = os.path.join(self.root, base_dir) + image_dir = os.path.join(voc_root, 'JPEGImages') + annotation_dir = os.path.join(voc_root, 'Annotations') if download: download_extract(self.url, self.root, self.filename, self.md5) - if not os.path.isdir(_voc_root): + if not os.path.isdir(voc_root): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') - _splits_dir = os.path.join(_voc_root, 'ImageSets/Main') + splits_dir = os.path.join(voc_root, 'ImageSets/Main') - _split_f = os.path.join(_splits_dir, image_set.rstrip('\n') + '.txt') + split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') - if not os.path.exists(_split_f): + if not os.path.exists(split_f): raise ValueError( 'Wrong image_set entered! 
Please use image_set="train" ' 'or image_set="trainval" or image_set="val" or a valid' 'image_set from the VOC ImageSets/Main folder.') - self.images = [] - self.annotations = [] - with open(os.path.join(_split_f), "r") as lines: - for line in lines: - _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") - _annotation = os.path.join(_annotation_dir, - line.rstrip('\n') + ".xml") - assert os.path.isfile(_image) - assert os.path.isfile(_annotation) - self.images.append(_image) - self.annotations.append(_annotation) + with open(os.path.join(split_f), "r") as f: + file_names = [x.strip() for x in f.readlines()] + self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] + self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names] assert (len(self.images) == len(self.annotations)) def __getitem__(self, index): @@ -224,44 +209,41 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is a list of bounding boxes of - relative coordinates like``[[xmin, ymin, xmax, ymax, ind], [...], ...]``. + tuple: (image, target) where target is a dictionary of the XML tree. """ - _img = Image.open(self.images[index]).convert('RGB') - _target = self._get_bboxes(ET.parse(self.annotations[index]).getroot()) + img = Image.open(self.images[index]).convert('RGB') + target = self.parse_voc_xml( + ET.parse(self.annotations[index]).getroot()) if self.transform is not None: - _img = self.transform(_img) + img = self.transform(img) if self.target_transform is not None: - _target = self.target_transform(_target) + target = self.target_transform(target) - return _img, _target + return img, target def __len__(self): return len(self.images) - def _get_bboxes(self, target): - res = [] - for obj in target.iter('object'): - difficult = int(obj.find('difficult').text) == 1 - if not self.keep_difficult and difficult: - continue - name = obj.find('name').text.lower().strip() - bbox = obj.find('bndbox') - width = int(target.find('size').find('width').text) - height = int(target.find('size').find('height').text) - bndbox = [] - for i, cur_bb in enumerate(bbox): - bb_sz = int(cur_bb.text) - 1 - # relative coordinates - bb_sz = bb_sz / width if i % 2 == 0 else bb_sz / height - bndbox.append(bb_sz) - - label_ind = self.class_to_ind[name] - bndbox.append(label_ind) - res.append(bndbox) # [xmin, ymin, xmax, ymax, ind] - return res + def parse_voc_xml(self, node): + voc_dict = {} + children = list(node) + if children: + def_dic = collections.defaultdict(list) + for dc in map(self.parse_voc_xml, children): + for ind, v in dc.items(): + def_dic[ind].append(v) + voc_dict = { + node.tag: + {ind: v[0] if len(v) == 1 else v + for ind, v in def_dic.items()} + } + if node.text: + text = node.text.strip() + if not children: + voc_dict[node.tag] = text + return voc_dict def download_extract(url, root, filename, md5): From ad2e29d678315b655c9a8a7cb3f15d88785b723b Mon Sep 17 00:00:00 2001 From: Benjamin Pinaya Date: Thu, 6 Dec 2018 13:58:12 +0100 Subject: [PATCH 7/7] Removed unused variable, removed VOC_CLASSES, two new gists for test. 
--- torchvision/datasets/voc.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 95a29fe59da..f886d701bca 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -11,11 +11,6 @@ from PIL import Image from .utils import download_url, check_integrity -VOC_CLASSES = [ - 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', - 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', - 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' -] DATASET_YEAR_DICT = { '2012': { 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', @@ -148,7 +143,6 @@ class VOCDetection(data.Dataset): download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. - class_to_ind (dict, optional): dictionary lookup of classnames -> indexes (default: alphabetic indexing of VOC's 20 classes). transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` @@ -161,7 +155,6 @@ def __init__(self, year='2012', image_set='train', download=False, - class_to_ind=None, transform=None, target_transform=None): self.root = root @@ -172,8 +165,7 @@ def __init__(self, self.transform = transform self.target_transform = target_transform self.image_set = image_set - self.class_to_ind = class_to_ind or dict( - zip(VOC_CLASSES, range(len(VOC_CLASSES)))) + base_dir = DATASET_YEAR_DICT[year]['base_dir'] voc_root = os.path.join(self.root, base_dir) image_dir = os.path.join(voc_root, 'JPEGImages')
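
For reference, the sketch below shows how the two datasets introduced by this series might be used once the final patch is applied. It is not part of the patches themselves: the ./data root, the ToTensor transform, and download=True are placeholder choices, and the detection target follows the nested-dict format produced by parse_voc_xml after PATCH 6/7.

# Usage sketch (illustration only, not part of the patch series).
# Assumes torchvision with this series applied; paths and transforms are
# placeholders.
import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation, VOCDetection

# Segmentation: the target is the palette-mode PIL image from
# SegmentationClass/, so target_transform is left unset here to inspect it.
seg = VOCSegmentation(root='./data', year='2012', image_set='train',
                      download=True, transform=transforms.ToTensor())
img, mask = seg[0]      # img: 3xHxW tensor, mask: PIL.Image

# Detection: since PATCH 6/7 the target is a nested dict built from the
# annotation XML, not a list of [xmin, ymin, xmax, ymax, ind] boxes.
det = VOCDetection(root='./data', year='2012', image_set='train',
                   download=True, transform=transforms.ToTensor())
img, target = det[0]
objects = target['annotation']['object']
# A single instance parses to a dict, several instances to a list of dicts.
if isinstance(objects, dict):
    objects = [objects]
for obj in objects:
    print(obj['name'], obj['bndbox'])   # e.g. 'dog' {'xmin': '48', ...}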
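
The shape of that detection target can also be checked without downloading any VOC data: the snippet below is a standalone copy of the parse_voc_xml logic from PATCH 6/7 applied to a made-up annotation fragment (the file name and box coordinates are invented for illustration).

# Standalone copy of VOCDetection.parse_voc_xml (PATCH 6/7) for illustration;
# the XML fragment below is invented.
import collections
import xml.etree.ElementTree as ET

def parse_voc_xml(node):
    voc_dict = {}
    children = list(node)
    if children:
        grouped = collections.defaultdict(list)
        for child_dict in map(parse_voc_xml, children):
            for tag, value in child_dict.items():
                grouped[tag].append(value)
        # A tag that appears once maps to its value, repeated tags to a list.
        voc_dict = {node.tag: {tag: v[0] if len(v) == 1 else v
                               for tag, v in grouped.items()}}
    elif node.text:
        voc_dict[node.tag] = node.text.strip()
    return voc_dict

fragment = """
<annotation>
  <filename>example.jpg</filename>
  <object>
    <name>person</name>
    <bndbox><xmin>174</xmin><ymin>101</ymin><xmax>349</xmax><ymax>351</ymax></bndbox>
  </object>
</annotation>
"""
print(parse_voc_xml(ET.fromstring(fragment)))
# {'annotation': {'filename': 'example.jpg',
#                 'object': {'name': 'person',
#                            'bndbox': {'xmin': '174', 'ymin': '101',
#                                       'xmax': '349', 'ymax': '351'}}}}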