diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py
index 7fceb5f25e6..1940c9e8493 100644
--- a/test/test_datasets_utils.py
+++ b/test/test_datasets_utils.py
@@ -107,6 +107,11 @@ def test_extract_gzip(self):
                     data = nf.read()
                 self.assertEqual(data, 'this is the content')
 
+    def test_verify_str_arg(self):
+        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
+        self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
+        self.assertRaises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py
index 6c14500f4f3..e18349d76a0 100644
--- a/torchvision/datasets/caltech.py
+++ b/torchvision/datasets/caltech.py
@@ -4,7 +4,7 @@
 import os.path
 
 from .vision import VisionDataset
-from .utils import download_and_extract_archive, makedir_exist_ok
+from .utils import download_and_extract_archive, makedir_exist_ok, verify_str_arg
 
 
 class Caltech101(VisionDataset):
@@ -32,10 +32,10 @@ def __init__(self, root, target_type="category", transform=None,
                                          transform=transform,
                                          target_transform=target_transform)
         makedir_exist_ok(self.root)
-        if isinstance(target_type, list):
-            self.target_type = target_type
-        else:
-            self.target_type = [target_type]
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
+                            for t in target_type]
 
         if download:
             self.download()
@@ -88,8 +88,6 @@ def __getitem__(self, index):
                                                      self.annotation_categories[self.y[index]],
                                                      "annotation_{:04d}.mat".format(self.index[index])))
                 target.append(data["obj_contour"])
-            else:
-                raise ValueError("Target type \"{}\" is not recognized.".format(t))
         target = tuple(target) if len(target) > 1 else target[0]
 
         if self.transform is not None:
diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index 88033f3f375..8d593218e46 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -3,7 +3,7 @@
 import os
 import PIL
 from .vision import VisionDataset
-from .utils import download_file_from_google_drive, check_integrity
+from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
 
 
 class CelebA(VisionDataset):
@@ -66,17 +66,14 @@ def __init__(self, root, split="train", target_type="attr", transform=None,
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')
 
-        if split.lower() == "train":
-            split = 0
-        elif split.lower() == "valid":
-            split = 1
-        elif split.lower() == "test":
-            split = 2
-        elif split.lower() == "all":
-            split = None
-        else:
-            raise ValueError('Wrong split entered! Please use "train", '
-                             '"valid", "test", or "all"')
+        split_map = {
+            "train": 0,
+            "valid": 1,
+            "test": 2,
+            "all": None,
+        }
+        split = split_map[verify_str_arg(split.lower(), "split",
+                                         ("train", "valid", "test", "all"))]
 
         fn = partial(os.path.join, self.root, self.base_folder)
         splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
@@ -134,6 +131,7 @@ def __getitem__(self, index):
             elif t == "landmarks":
                 target.append(self.landmarks_align[index, :])
             else:
+                # TODO: refactor with utils.verify_str_arg
                 raise ValueError("Target type \"{}\" is not recognized.".format(t))
 
         target = tuple(target) if len(target) > 1 else target[0]
diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py
index cfa7f0043b3..56ff20bc3f1 100644
--- a/torchvision/datasets/cityscapes.py
+++ b/torchvision/datasets/cityscapes.py
@@ -3,7 +3,7 @@
 from collections import namedtuple
 import zipfile
 
-from .utils import extract_archive
+from .utils import extract_archive, verify_str_arg, iterable_to_str
 from .vision import VisionDataset
 from PIL import Image
 
@@ -109,22 +109,21 @@ def __init__(self, root, split='train', mode='fine', target_type='instance',
         self.images = []
         self.targets = []
 
-        if mode not in ['fine', 'coarse']:
-            raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')
-
-        if mode == 'fine' and split not in ['train', 'test', 'val']:
-            raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"'
-                             ' or split="val"')
-        elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
-            raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"'
-                             ' or split="val"')
+        verify_str_arg(mode, "mode", ("fine", "coarse"))
+        if mode == "fine":
+            valid_modes = ("train", "test", "val")
+        else:
+            valid_modes = ("train", "train_extra", "val")
+        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
+               "Valid values are {{{}}}.")
+        msg = msg.format(split, mode, iterable_to_str(valid_modes))
+        verify_str_arg(split, "split", valid_modes, msg)
 
         if not isinstance(target_type, list):
             self.target_type = [target_type]
-
-        if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
-            raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
-                             ' or "color"')
+        [verify_str_arg(value, "target_type",
+                        ("instance", "semantic", "polygon", "color"))
+         for value in self.target_type]
 
         if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py
index d6eead13a35..14a256c66ab 100644
--- a/torchvision/datasets/imagenet.py
+++ b/torchvision/datasets/imagenet.py
@@ -4,7 +4,8 @@
 import tempfile
 
 import torch
 from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive
+from .utils import check_integrity, download_and_extract_archive, extract_archive, \
+    verify_str_arg
 
 ARCHIVE_DICT = {
     'train': {
@@ -48,7 +49,7 @@ class ImageNet(ImageFolder):
     def __init__(self, root, split='train', download=False, **kwargs):
         root = self.root = os.path.expanduser(root)
-        self.split = self._verify_split(split)
+        self.split = verify_str_arg(split, "split", ("train", "val"))
 
         if download:
             self.download()
@@ -109,17 +110,6 @@ def _load_meta_file(self):
     def _save_meta_file(self, wnid_to_class, val_wnids):
         torch.save((wnid_to_class, val_wnids), self.meta_file)
 
-    def _verify_split(self, split):
-        if split not in self.valid_splits:
-            msg = "Unknown split {} .".format(split)
-            msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits))
-            raise ValueError(msg)
-        return split
-
-    @property
-    def valid_splits(self):
-        return 'train', 'val'
-
     @property
     def split_folder(self):
         return os.path.join(self.root, self.split)
diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py
index aec65722cd9..66234a35aeb 100644
--- a/torchvision/datasets/lsun.py
+++ b/torchvision/datasets/lsun.py
@@ -11,6 +11,8 @@
 else:
     import pickle
 
+from .utils import verify_str_arg, iterable_to_str
+
 
 class LSUNClass(VisionDataset):
     def __init__(self, root, transform=None, target_transform=None):
@@ -75,27 +77,31 @@ def __init__(self, root, classes='train', transform=None, target_transform=None)
                       'living_room', 'restaurant', 'tower']
         dset_opts = ['train', 'val', 'test']
-        if type(classes) == str and classes in dset_opts:
+        try:
+            verify_str_arg(classes, "classes", dset_opts)
             if classes == 'test':
                 classes = [classes]
             else:
                 classes = [c + '_' + classes for c in categories]
-        elif type(classes) == list:
+        except ValueError:
+            # TODO: Should this check for Iterable instead of list?
+            if not isinstance(classes, list):
+                raise ValueError
             for c in classes:
+                # TODO: This assumes each item is a str (or subclass). Should this
+                # also be checked?
                 c_short = c.split('_')
-                c_short.pop(len(c_short) - 1)
-                c_short = '_'.join(c_short)
-                if c_short not in categories:
-                    raise (ValueError('Unknown LSUN class: ' + c_short + '.'
-                                      'Options are: ' + str(categories)))
-                c_short = c.split('_')
-                c_short = c_short.pop(len(c_short) - 1)
-                if c_short not in dset_opts:
-                    raise (ValueError('Unknown postfix: ' + c_short + '.'
-                                      'Options are: ' + str(dset_opts)))
-        else:
-            raise (ValueError('Unknown option for classes'))
-        self.classes = classes
+                category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
+                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+
+                msg = msg_fmtstr.format(category, "LSUN class",
+                                        iterable_to_str(categories))
+                verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+        finally:
+            self.classes = classes
 
         # for each class, create an LSUNClassDataset
         self.dbs = []
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 4847cfdcaa9..2109decff33 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -7,7 +7,8 @@
 import numpy as np
 import torch
 import codecs
-from .utils import download_url, download_and_extract_archive, extract_archive, makedir_exist_ok
+from .utils import download_url, download_and_extract_archive, extract_archive, \
+    makedir_exist_ok, verify_str_arg
 
 
 class MNIST(VisionDataset):
@@ -230,11 +231,7 @@ class EMNIST(MNIST):
     splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
 
     def __init__(self, root, split, **kwargs):
-        if split not in self.splits:
-            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
-                split, ', '.join(self.splits),
-            ))
-        self.split = split
+        self.split = verify_str_arg(split, "split", self.splits)
         self.training_file = self._training_file(split)
         self.test_file = self._test_file(split)
         super(EMNIST, self).__init__(root, **kwargs)
@@ -336,10 +333,7 @@ class QMNIST(MNIST):
     def __init__(self, root, what=None, compat=True, train=True, **kwargs):
         if what is None:
             what = 'train' if train else 'test'
-        if not self.subsets.get(what):
-            raise RuntimeError("Argument 'what' should be one of: \n " +
-                               repr(tuple(self.subsets.keys())))
-        self.what = what
+        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
         self.compat = compat
         self.data_file = what + '.pt'
         self.training_file = self.data_file
diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py
index 3c9202fdbfd..c4713f72576 100644
--- a/torchvision/datasets/sbd.py
+++ b/torchvision/datasets/sbd.py
@@ -5,7 +5,7 @@
 import numpy as np
 
 from PIL import Image
-from .utils import download_url
+from .utils import download_url, verify_str_arg
 from .voc import download_extract
 
@@ -64,12 +64,9 @@ def __init__(self,
                               "pip install scipy")
 
         super(SBDataset, self).__init__(root, transforms)
-
-        if mode not in ("segmentation", "boundaries"):
-            raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
-
-        self.image_set = image_set
-        self.mode = mode
+        self.image_set = verify_str_arg(image_set, "image_set",
+                                        ("train", "val", "train_noval"))
+        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
         self.num_classes = 20
 
         sbd_root = self.root
@@ -91,11 +88,6 @@ def __init__(self,
 
         split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
 
-        if not os.path.exists(split_f):
-            raise ValueError(
-                'Wrong image_set entered! Please use image_set="train" '
-                'or image_set="val" or image_set="train_noval"')
-
         with open(os.path.join(split_f), "r") as f:
             file_names = [x.strip() for x in f.readlines()]
 
diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py
index 7c49877bfcc..863dbf0fbf4 100644
--- a/torchvision/datasets/stl10.py
+++ b/torchvision/datasets/stl10.py
@@ -5,7 +5,7 @@
 import numpy as np
 
 from .vision import VisionDataset
-from .utils import check_integrity, download_and_extract_archive
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
 
 
 class STL10(VisionDataset):
@@ -48,13 +48,9 @@ class STL10(VisionDataset):
 
     def __init__(self, root, split='train', folds=None, transform=None,
                  target_transform=None, download=False):
-        if split not in self.splits:
-            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
-                split, ', '.join(self.splits),
-            ))
         super(STL10, self).__init__(root, transform=transform,
                                     target_transform=target_transform)
-        self.split = split  # train/test/unlabeled set
+        self.split = verify_str_arg(split, "split", self.splits)
         self.folds = folds  # one of the 10 pre-defined folds or the full dataset
 
         if download:
@@ -167,4 +163,6 @@ def __load_folds(self, folds):
             list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
             self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
         else:
+            # FIXME: docstring allows None for folds (it is even the default value)
+            # Is this intended?
             raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py
index 0ceb1099843..4d8ed990bdb 100644
--- a/torchvision/datasets/svhn.py
+++ b/torchvision/datasets/svhn.py
@@ -4,7 +4,7 @@
 import os
 import os.path
 import numpy as np
-from .utils import download_url, check_integrity
+from .utils import download_url, check_integrity, verify_str_arg
 
 
 class SVHN(VisionDataset):
@@ -43,12 +43,7 @@ def __init__(self, root, split='train', transform=None,
                  target_transform=None, download=False):
         super(SVHN, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
-        self.split = split  # training set or test set or extra set
-
-        if self.split not in self.split_list:
-            raise ValueError('Wrong split entered! Please use split="train" '
-                             'or split="extra" or split="test"')
-
+        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
         self.url = self.split_list[split][0]
         self.filename = self.split_list[split][1]
         self.file_md5 = self.split_list[split][2]
diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py
index 27315e6e912..937bed90399 100644
--- a/torchvision/datasets/utils.py
+++ b/torchvision/datasets/utils.py
@@ -6,6 +6,7 @@
 import tarfile
 import zipfile
 
+import torch
 from torch.utils.model_zoo import tqdm
 
 
@@ -249,3 +250,32 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename
     archive = os.path.join(download_root, filename)
     print("Extracting {} to {}".format(archive, extract_root))
     extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable):
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
+    if not isinstance(value, torch._six.string_classes):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = ("Unknown value '{value}' for argument {arg}. "
+                   "Valid values are {{{valid_values}}}.")
+            msg = msg.format(value=value, arg=arg,
+                             valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py
index 390c441911d..8a6925011ba 100644
--- a/torchvision/datasets/voc.py
+++ b/torchvision/datasets/voc.py
@@ -10,7 +10,7 @@
 import xml.etree.ElementTree as ET
 
 from PIL import Image
-from .utils import download_url, check_integrity
+from .utils import download_url, check_integrity, verify_str_arg
 
 DATASET_YEAR_DICT = {
     '2012': {
@@ -83,7 +83,8 @@ def __init__(self,
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
         self.md5 = DATASET_YEAR_DICT[year]['md5']
-        self.image_set = image_set
+        self.image_set = verify_str_arg(image_set, "image_set",
+                                        ("train", "trainval", "val"))
         base_dir = DATASET_YEAR_DICT[year]['base_dir']
         voc_root = os.path.join(self.root, base_dir)
         image_dir = os.path.join(voc_root, 'JPEGImages')
@@ -100,11 +101,6 @@ def __init__(self,
 
         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
 
-        if not os.path.exists(split_f):
-            raise ValueError(
-                'Wrong image_set entered! Please use image_set="train" '
-                'or image_set="trainval" or image_set="val"')
-
         with open(os.path.join(split_f), "r") as f:
             file_names = [x.strip() for x in f.readlines()]
 
@@ -164,7 +160,8 @@ def __init__(self,
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
         self.md5 = DATASET_YEAR_DICT[year]['md5']
-        self.image_set = image_set
+        self.image_set = verify_str_arg(image_set, "image_set",
+                                        ("train", "trainval", "val"))
         base_dir = DATASET_YEAR_DICT[year]['base_dir']
         voc_root = os.path.join(self.root, base_dir)
 
@@ -182,12 +179,6 @@ def __init__(self,
 
         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
 
-        if not os.path.exists(split_f):
-            raise ValueError(
-                'Wrong image_set entered! Please use image_set="train" '
-                'or image_set="trainval" or image_set="val" or a valid'
-                'image_set from the VOC ImageSets/Main folder.')
-
         with open(os.path.join(split_f), "r") as f:
             file_names = [x.strip() for x in f.readlines()]
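
For reviewers, a minimal usage sketch of the new helper follows. It is illustrative only and not part of the patch above; it assumes the patched torchvision.datasets.utils from this diff (with verify_str_arg and iterable_to_str) is importable.

# Illustrative sketch, not part of the patch. Mirrors the verify_str_arg
# signature added to torchvision/datasets/utils.py in this diff.
from torchvision.datasets.utils import iterable_to_str, verify_str_arg

# A valid value is returned unchanged, so the call can be used inline in an assignment.
split = verify_str_arg("train", "split", ("train", "val"))
assert split == "train"

# An unknown value raises ValueError with a uniform message listing the valid values.
try:
    verify_str_arg("training", "split", ("train", "val"))
except ValueError as e:
    print(e)  # Unknown value 'training' for argument split. Valid values are {'train', 'val'}.

# A non-str value is rejected before the membership check.
try:
    verify_str_arg(0, "split", ("train", "val"))
except ValueError as e:
    print(e)  # Expected type str for argument split, but got type <class 'int'>.

# custom_msg overrides the default message, as the Cityscapes and LSUN changes do.
msg = "Unknown value '{}' for LSUN class. Valid values are {{{}}}.".format(
    "kitchen", iterable_to_str(("bedroom", "church_outdoor")))
try:
    verify_str_arg("kitchen", valid_values=("bedroom", "church_outdoor"), custom_msg=msg)
except ValueError as e:
    print(e)  # Unknown value 'kitchen' for LSUN class. Valid values are {'bedroom', 'church_outdoor'}.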