From ffaa9ab44f03fc739bbab40ceb0eeefead269ae7 Mon Sep 17 00:00:00 2001 From: "ernest.parke" Date: Mon, 4 Jun 2018 13:42:26 -0400 Subject: [PATCH] Addresses #145, based off of @fmassa --- torchvision/datasets/folder.py | 73 ++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 4352b120d07..4ec5a325f03 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -32,29 +32,6 @@ def is_image_file(filename): return has_file_allowed_extension(filename, IMG_EXTENSIONS) -def find_classes(dir): - classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] - classes.sort() - class_to_idx = {classes[i]: i for i in range(len(classes))} - return classes, class_to_idx - - -def make_dataset(dir, class_to_idx, extensions): - images = [] - dir = os.path.expanduser(dir) - for target in sorted(os.listdir(dir)): - d = os.path.join(dir, target) - if not os.path.isdir(d): - continue - - for root, _, fnames in sorted(os.walk(d)): - for fname in sorted(fnames): - if has_file_allowed_extension(fname, extensions): - path = os.path.join(root, fname) - item = (path, class_to_idx[target]) - images.append(item) - - return images class DatasetFolder(data.Dataset): @@ -86,8 +63,8 @@ class DatasetFolder(data.Dataset): """ def __init__(self, root, loader, extensions, transform=None, target_transform=None): - classes, class_to_idx = find_classes(root) - samples = make_dataset(root, class_to_idx, extensions) + classes, class_to_idx = self._find_classes(root) + samples = self._make_dataset(root, class_to_idx, extensions) if len(samples) == 0: raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" "Supported extensions are: " + ",".join(extensions))) @@ -104,6 +81,52 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No self.transform = transform self.target_transform = target_transform + def _find_classes(dir): + """ + Finds the classes in a dataset directory. + + Args: + dir (string): Root directory path. + + Returns: + tuple: (classes, class_to_idx) where class_to_idx is a dictionary + """ + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + + def _make_dataset(dir, class_to_idx, extensions): + """ + A generic method for obtaining paths to all data files. + + Args: + dir (string): Root directory path. + class_to_idx (dictionary): A mapping of class names to id's. + extensions (list): A list of permitted data file extensions. + + Returns: + images: A list of (path, target) per data file. + + """ + images = [] + dir = os.path.expanduser(dir) + for target in sorted(os.listdir(dir)): + d = os.path.join(dir, target) + if not os.path.isdir(d): + continue + + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + if has_file_allowed_extension(fname, extensions): + path = os.path.join(root, fname) + item = (path, class_to_idx[target]) + images.append(item) + + return images + + def __getitem__(self, index): """ Args: