Skip to content

Commit ffaa9ab

Browse files
committed
Addresses pytorch#145, based off of @fmassa
1 parent 5a0d079 commit ffaa9ab

File tree

1 file changed

+48
-25
lines changed

1 file changed

+48
-25
lines changed

torchvision/datasets/folder.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,6 @@ def is_image_file(filename):
3232
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
3333

3434

35-
def find_classes(dir):
    """Discover the class folders directly under ``dir``.

    Args:
        dir (string): Root directory path.

    Returns:
        tuple: (classes, class_to_idx) where ``classes`` is the sorted list
            of sub-directory names and ``class_to_idx`` maps each name to
            its position in that list.
    """
    # Only directories count as classes; plain files in the root are ignored.
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx
40-
41-
42-
def make_dataset(dir, class_to_idx, extensions):
    """Collect (path, class_index) pairs for every permitted file under ``dir``.

    Args:
        dir (string): Root directory path; a leading ``~`` is expanded.
        class_to_idx (dictionary): Maps class (sub-directory) names to indices.
        extensions: Permitted file extensions, passed through to
            ``has_file_allowed_extension``.

    Returns:
        list: (path, class_index) tuples, visited in sorted class, sorted
            directory and sorted file-name order.
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if os.path.isdir(class_dir):
            # Walk the class folder recursively; sorting keeps the order stable.
            for folder, _, filenames in sorted(os.walk(class_dir)):
                # class_to_idx is only consulted when a file actually matches.
                samples.extend(
                    (os.path.join(folder, filename), class_to_idx[class_name])
                    for filename in sorted(filenames)
                    if has_file_allowed_extension(filename, extensions)
                )
    return samples
5835

5936

6037
class DatasetFolder(data.Dataset):
@@ -86,8 +63,8 @@ class DatasetFolder(data.Dataset):
8663
"""
8764

8865
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
89-
classes, class_to_idx = find_classes(root)
90-
samples = make_dataset(root, class_to_idx, extensions)
66+
classes, class_to_idx = self._find_classes(root)
67+
samples = self._make_dataset(root, class_to_idx, extensions)
9168
if len(samples) == 0:
9269
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
9370
"Supported extensions are: " + ",".join(extensions)))
@@ -104,6 +81,52 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
10481
self.transform = transform
10582
self.target_transform = target_transform
10683

84+
def _find_classes(self, dir):
    """
    Finds the class folders in a dataset directory.

    Args:
        dir (string): Root directory path.

    Returns:
        tuple: (classes, class_to_idx) where classes is the sorted list of
            sub-directory names found in ``dir`` and class_to_idx is a
            dictionary mapping each class name to its index in that list.
    """
    # ``self`` is required: __init__ invokes this as ``self._find_classes(root)``,
    # so the bound call passes the instance as the first positional argument.
    classes = sorted(d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
98+
99+
100+
def _make_dataset(self, dir, class_to_idx, extensions):
    """
    A generic method for obtaining paths to all data files.

    Args:
        dir (string): Root directory path; a leading ``~`` is expanded.
        class_to_idx (dictionary): A mapping of class names to ids.
        extensions (list): A list of permitted data file extensions.

    Returns:
        list: A list of (path, target) tuples, one per data file, where
            target is the class index of the top-level folder containing it.
    """
    # ``self`` is required: __init__ invokes this as
    # ``self._make_dataset(root, class_to_idx, extensions)``.
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            # Plain files directly under the root do not belong to any class.
            continue

        # Walk the class folder recursively; sorting keeps the order stable.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images
128+
129+
107130
def __getitem__(self, index):
108131
"""
109132
Args:

0 commit comments

Comments
 (0)