Skip to content

Commit

Permalink
Standardize str argument verification in datasets (#1167)
Browse files Browse the repository at this point in the history
* introduced function to verify str arguments

* flake8

* added FIXME to VOC

* Fixed error message

* added test for verify_str_arg

* cleanup todos

* added option for custom error message

* fix VOC

* fixed Caltech
  • Loading branch information
Philip Meier authored and fmassa committed Jul 26, 2019
1 parent d9830d8 commit 4886ccc
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 110 deletions.
5 changes: 5 additions & 0 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def test_extract_gzip(self):
data = nf.read()
self.assertEqual(data, 'this is the content')

def test_verify_str_arg(self):
    """Check verify_str_arg for the valid, non-str, and unknown-value cases."""
    # A valid value is returned unchanged.
    self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
    # Signature is verify_str_arg(value, arg, valid_values): the original
    # calls swapped arg and valid_values and only passed by accident
    # (e.g. "b" not in the *string* "arg" also raises ValueError).
    # A non-str value raises ValueError.
    self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
    # A str value outside valid_values raises ValueError.
    self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))


if __name__ == '__main__':
unittest.main()
12 changes: 5 additions & 7 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path

from .vision import VisionDataset
from .utils import download_and_extract_archive, makedir_exist_ok
from .utils import download_and_extract_archive, makedir_exist_ok, verify_str_arg


class Caltech101(VisionDataset):
Expand Down Expand Up @@ -32,10 +32,10 @@ def __init__(self, root, target_type="category", transform=None,
transform=transform,
target_transform=target_transform)
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
if not isinstance(target_type, list):
target_type = [target_type]
self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
for t in target_type]

if download:
self.download()
Expand Down Expand Up @@ -88,8 +88,6 @@ def __getitem__(self, index):
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index])))
target.append(data["obj_contour"])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]

if self.transform is not None:
Expand Down
22 changes: 10 additions & 12 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import PIL
from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg


class CelebA(VisionDataset):
Expand Down Expand Up @@ -66,17 +66,14 @@ def __init__(self, root, split="train", target_type="attr", transform=None,
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

if split.lower() == "train":
split = 0
elif split.lower() == "valid":
split = 1
elif split.lower() == "test":
split = 2
elif split.lower() == "all":
split = None
else:
raise ValueError('Wrong split entered! Please use "train", '
'"valid", "test", or "all"')
split_map = {
"train": 0,
"valid": 1,
"test": 2,
"all": None,
}
split = split_map[verify_str_arg(split.lower(), "split",
("train", "valid", "test", "all"))]

fn = partial(os.path.join, self.root, self.base_folder)
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
Expand Down Expand Up @@ -134,6 +131,7 @@ def __getitem__(self, index):
elif t == "landmarks":
target.append(self.landmarks_align[index, :])
else:
# TODO: refactor with utils.verify_str_arg
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]

Expand Down
27 changes: 13 additions & 14 deletions torchvision/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import namedtuple
import zipfile

from .utils import extract_archive
from .utils import extract_archive, verify_str_arg, iterable_to_str
from .vision import VisionDataset
from PIL import Image

Expand Down Expand Up @@ -109,22 +109,21 @@ def __init__(self, root, split='train', mode='fine', target_type='instance',
self.images = []
self.targets = []

if mode not in ['fine', 'coarse']:
raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')

if mode == 'fine' and split not in ['train', 'test', 'val']:
raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"'
' or split="val"')
elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"'
' or split="val"')
verify_str_arg(mode, "mode", ("fine", "coarse"))
if mode == "fine":
valid_modes = ("train", "test", "val")
else:
valid_modes = ("train", "train_extra", "val")
msg = ("Unknown value '{}' for argument split if mode is '{}'. "
"Valid values are {{{}}}.")
msg = msg.format(split, mode, iterable_to_str(valid_modes))
verify_str_arg(split, "split", valid_modes, msg)

if not isinstance(target_type, list):
self.target_type = [target_type]

if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
' or "color"')
[verify_str_arg(value, "target_type",
("instance", "semantic", "polygon", "color"))
for value in self.target_type]

if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

Expand Down
16 changes: 3 additions & 13 deletions torchvision/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import tempfile
import torch
from .folder import ImageFolder
from .utils import check_integrity, download_and_extract_archive, extract_archive
from .utils import check_integrity, download_and_extract_archive, extract_archive, \
verify_str_arg

ARCHIVE_DICT = {
'train': {
Expand Down Expand Up @@ -48,7 +49,7 @@ class ImageNet(ImageFolder):

def __init__(self, root, split='train', download=False, **kwargs):
root = self.root = os.path.expanduser(root)
self.split = self._verify_split(split)
self.split = verify_str_arg(split, "split", ("train", "val"))

if download:
self.download()
Expand Down Expand Up @@ -109,17 +110,6 @@ def _load_meta_file(self):
def _save_meta_file(self, wnid_to_class, val_wnids):
    # Persist the dataset metadata — the (wnid -> class names) mapping and
    # the list of validation wnids — to ``self.meta_file`` as one tuple,
    # so later runs can reload it via _load_meta_file instead of re-parsing.
    torch.save((wnid_to_class, val_wnids), self.meta_file)

def _verify_split(self, split):
if split not in self.valid_splits:
msg = "Unknown split {} .".format(split)
msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits))
raise ValueError(msg)
return split

@property
def valid_splits(self):
return 'train', 'val'

@property
def split_folder(self):
    # Directory containing the images of the currently selected split,
    # e.g. <root>/train or <root>/val.
    return os.path.join(self.root, self.split)
Expand Down
36 changes: 21 additions & 15 deletions torchvision/datasets/lsun.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
else:
import pickle

from .utils import verify_str_arg, iterable_to_str


class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None):
Expand Down Expand Up @@ -75,27 +77,31 @@ def __init__(self, root, classes='train', transform=None, target_transform=None)
'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test']

if type(classes) == str and classes in dset_opts:
try:
verify_str_arg(classes, "classes", dset_opts)
if classes == 'test':
classes = [classes]
else:
classes = [c + '_' + classes for c in categories]
elif type(classes) == list:
except ValueError:
# TODO: Should this check for Iterable instead of list?
if not isinstance(classes, list):
raise ValueError
for c in classes:
# TODO: This assumes each item is a str (or subclass). Should this
# also be checked?
c_short = c.split('_')
c_short.pop(len(c_short) - 1)
c_short = '_'.join(c_short)
if c_short not in categories:
raise (ValueError('Unknown LSUN class: ' + c_short + '.'
'Options are: ' + str(categories)))
c_short = c.split('_')
c_short = c_short.pop(len(c_short) - 1)
if c_short not in dset_opts:
raise (ValueError('Unknown postfix: ' + c_short + '.'
'Options are: ' + str(dset_opts)))
else:
raise (ValueError('Unknown option for classes'))
self.classes = classes
category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."

msg = msg_fmtstr.format(category, "LSUN class",
iterable_to_str(categories))
verify_str_arg(category, valid_values=categories, custom_msg=msg)

msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
finally:
self.classes = classes

# for each class, create an LSUNClassDataset
self.dbs = []
Expand Down
14 changes: 4 additions & 10 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import numpy as np
import torch
import codecs
from .utils import download_url, download_and_extract_archive, extract_archive, makedir_exist_ok
from .utils import download_url, download_and_extract_archive, extract_archive, \
makedir_exist_ok, verify_str_arg


class MNIST(VisionDataset):
Expand Down Expand Up @@ -230,11 +231,7 @@ class EMNIST(MNIST):
splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

def __init__(self, root, split, **kwargs):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
self.split = split
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs)
Expand Down Expand Up @@ -336,10 +333,7 @@ class QMNIST(MNIST):
def __init__(self, root, what=None, compat=True, train=True, **kwargs):
if what is None:
what = 'train' if train else 'test'
if not self.subsets.get(what):
raise RuntimeError("Argument 'what' should be one of: \n " +
repr(tuple(self.subsets.keys())))
self.what = what
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
self.compat = compat
self.data_file = what + '.pt'
self.training_file = self.data_file
Expand Down
16 changes: 4 additions & 12 deletions torchvision/datasets/sbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from PIL import Image
from .utils import download_url
from .utils import download_url, verify_str_arg
from .voc import download_extract


Expand Down Expand Up @@ -64,12 +64,9 @@ def __init__(self,
"pip install scipy")

super(SBDataset, self).__init__(root, transforms)

if mode not in ("segmentation", "boundaries"):
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")

self.image_set = image_set
self.mode = mode
self.image_set = verify_str_arg(image_set, "image_set",
("train", "val", "train_noval"))
self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
self.num_classes = 20

sbd_root = self.root
Expand All @@ -91,11 +88,6 @@ def __init__(self,

split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')

if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="val" or image_set="train_noval"')

with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]

Expand Down
10 changes: 4 additions & 6 deletions torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive
from .utils import check_integrity, download_and_extract_archive, verify_str_arg


class STL10(VisionDataset):
Expand Down Expand Up @@ -48,13 +48,9 @@ class STL10(VisionDataset):

def __init__(self, root, split='train', folds=None, transform=None,
target_transform=None, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = split # train/test/unlabeled set
self.split = verify_str_arg(split, "split", self.splits)
self.folds = folds # one of the 10 pre-defined folds or the full dataset

if download:
Expand Down Expand Up @@ -167,4 +163,6 @@ def __load_folds(self, folds):
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
else:
# FIXME: docstring allows None for folds (it is even the default value)
# Is this intended?
raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
9 changes: 2 additions & 7 deletions torchvision/datasets/svhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import os.path
import numpy as np
from .utils import download_url, check_integrity
from .utils import download_url, check_integrity, verify_str_arg


class SVHN(VisionDataset):
Expand Down Expand Up @@ -43,12 +43,7 @@ def __init__(self, root, split='train', transform=None, target_transform=None,
download=False):
super(SVHN, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = split # training set or test set or extra set

if self.split not in self.split_list:
raise ValueError('Wrong split entered! Please use split="train" '
'or split="extra" or split="test"')

self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
self.url = self.split_list[split][0]
self.filename = self.split_list[split][1]
self.file_md5 = self.split_list[split][2]
Expand Down
30 changes: 30 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tarfile
import zipfile

import torch
from torch.utils.model_zoo import tqdm


Expand Down Expand Up @@ -249,3 +250,32 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename
archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable):
    """Render *iterable* as a comma-separated list of single-quoted items.

    Each item is converted with str(); an empty iterable yields ``"''"``.
    """
    joined = "', '".join(str(item) for item in iterable)
    return "'{}'".format(joined)


def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Verify that ``value`` is a str and, optionally, one of ``valid_values``.

    Args:
        value: the argument value to check.
        arg (str, optional): name of the argument, used in error messages.
        valid_values (iterable of str, optional): allowed values; if None,
            only the type of ``value`` is checked.
        custom_msg (str, optional): message to use instead of the default
            one when ``value`` is not in ``valid_values``.

    Returns:
        The unchanged ``value`` if all checks pass.

    Raises:
        ValueError: if ``value`` is not a str, or not in ``valid_values``.
    """
    # On Python 3, torch._six.string_classes was exactly (str,); torch._six
    # was removed in PyTorch 1.13, so check against str directly.
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        # ValueError (not TypeError) is kept for backward compatibility with
        # callers and tests that catch ValueError.
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
Loading

0 comments on commit 4886ccc

Please sign in to comment.