Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VOCSegmentation, VOCDetection, linting passing, examples. #663

Merged
merged 8 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,15 @@ Flickr
.. autoclass:: Flickr30k
:members: __getitem__
:special-members:

VOC
~~~~~~


.. autoclass:: VOCSegmentation
:members: __getitem__
:special-members:

.. autoclass:: VOCDetection
:members: __getitem__
:special-members:
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from .omniglot import Omniglot
from .sbu import SBU
from .flickr import Flickr8k, Flickr30k
from .voc import VOCSegmentation, VOCDetection

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k')
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection')
244 changes: 244 additions & 0 deletions torchvision/datasets/voc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import os
import sys
import tarfile
import collections
import torch.utils.data as data
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET

from PIL import Image
from .utils import download_url, check_integrity

DATASET_YEAR_DICT = {
bpinaya marked this conversation as resolved.
Show resolved Hide resolved
'2012': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '6cd6e144f989b92b3379bac3b3de84fd',
'base_dir': 'VOCdevkit/VOC2012'
},
'2011': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
'filename': 'VOCtrainval_25-May-2011.tar',
'md5': '6c3384ef61512963050cb5d687e5bf1e',
'base_dir': 'TrainVal/VOCdevkit/VOC2011'
},
'2010': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
'filename': 'VOCtrainval_03-May-2010.tar',
'md5': 'da459979d0c395079b5c75ee67908abb',
'base_dir': 'VOCdevkit/VOC2010'
},
'2009': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
'filename': 'VOCtrainval_11-May-2009.tar',
'md5': '59065e4b188729180974ef6572f6a212',
'base_dir': 'VOCdevkit/VOC2009'
},
'2008': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '2629fa636546599198acfcfbfcf1904a',
'base_dir': 'VOCdevkit/VOC2008'
},
'2007': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
'filename': 'VOCtrainval_06-Nov-2007.tar',
'md5': 'c52e279531787c972589f7e41ab4ae64',
'base_dir': 'VOCdevkit/VOC2007'
}
}


class VOCSegmentation(data.Dataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

def __init__(self,
root,
year='2012',
image_set='train',
download=False,
transform=None,
target_transform=None):
self.root = root
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.target_transform = target_transform
self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClass')

if download:
download_extract(self.url, self.root, self.filename, self.md5)

if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')

split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="trainval" or image_set="val"')

with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]

self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.masks[index])

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.images)


class VOCDetection(data.Dataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
(default: alphabetic indexing of VOC's 20 classes).
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, required): A function/transform that takes in the
target and transforms it.
"""

def __init__(self,
root,
year='2012',
image_set='train',
download=False,
transform=None,
target_transform=None):
self.root = root
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.target_transform = target_transform
self.image_set = image_set

base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
annotation_dir = os.path.join(voc_root, 'Annotations')

if download:
download_extract(self.url, self.root, self.filename, self.md5)

if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

splits_dir = os.path.join(voc_root, 'ImageSets/Main')

split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="trainval" or image_set="val" or a valid'
'image_set from the VOC ImageSets/Main folder.')

with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]

self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
assert (len(self.images) == len(self.annotations))

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is a dictionary of the XML tree.
"""
img = Image.open(self.images[index]).convert('RGB')
target = self.parse_voc_xml(
ET.parse(self.annotations[index]).getroot())

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.images)

def parse_voc_xml(self, node):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great, thanks!
Just to double check, did you try parsing all the images in say VOC2012? I know that some images have a single object and that might require some special handling, just want to verify that this is indeed being taken care here.

voc_dict = {}
children = list(node)
if children:
def_dic = collections.defaultdict(list)
for dc in map(self.parse_voc_xml, children):
for ind, v in dc.items():
def_dic[ind].append(v)
voc_dict = {
node.tag:
{ind: v[0] if len(v) == 1 else v
for ind, v in def_dic.items()}
}
if node.text:
text = node.text.strip()
if not children:
voc_dict[node.tag] = text
return voc_dict


def download_extract(url, root, filename, md5):
download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root)