@@ -32,29 +32,6 @@ def is_image_file(filename):
32
32
return has_file_allowed_extension (filename , IMG_EXTENSIONS )
33
33
34
34
35
- def find_classes (dir ):
36
- classes = [d for d in os .listdir (dir ) if os .path .isdir (os .path .join (dir , d ))]
37
- classes .sort ()
38
- class_to_idx = {classes [i ]: i for i in range (len (classes ))}
39
- return classes , class_to_idx
40
-
41
-
42
- def make_dataset (dir , class_to_idx , extensions ):
43
- images = []
44
- dir = os .path .expanduser (dir )
45
- for target in sorted (os .listdir (dir )):
46
- d = os .path .join (dir , target )
47
- if not os .path .isdir (d ):
48
- continue
49
-
50
- for root , _ , fnames in sorted (os .walk (d )):
51
- for fname in sorted (fnames ):
52
- if has_file_allowed_extension (fname , extensions ):
53
- path = os .path .join (root , fname )
54
- item = (path , class_to_idx [target ])
55
- images .append (item )
56
-
57
- return images
58
35
59
36
60
37
class DatasetFolder (data .Dataset ):
@@ -86,8 +63,8 @@ class DatasetFolder(data.Dataset):
86
63
"""
87
64
88
65
def __init__ (self , root , loader , extensions , transform = None , target_transform = None ):
89
- classes , class_to_idx = find_classes (root )
90
- samples = make_dataset (root , class_to_idx , extensions )
66
+ classes , class_to_idx = self . _find_classes (root )
67
+ samples = self . _make_dataset (root , class_to_idx , extensions )
91
68
if len (samples ) == 0 :
92
69
raise (RuntimeError ("Found 0 files in subfolders of: " + root + "\n "
93
70
"Supported extensions are: " + "," .join (extensions )))
@@ -104,6 +81,52 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
104
81
self .transform = transform
105
82
self .target_transform = target_transform
106
83
84
+ def _find_classes (dir ):
85
+ """
86
+ Finds the classes in a dataset directory.
87
+
88
+ Args:
89
+ dir (string): Root directory path.
90
+
91
+ Returns:
92
+ tuple: (classes, class_to_idx) where class_to_idx is a dictionary
93
+ """
94
+ classes = [d for d in os .listdir (dir ) if os .path .isdir (os .path .join (dir , d ))]
95
+ classes .sort ()
96
+ class_to_idx = {classes [i ]: i for i in range (len (classes ))}
97
+ return classes , class_to_idx
98
+
99
+
100
+ def _make_dataset (dir , class_to_idx , extensions ):
101
+ """
102
+ A generic method for obtaining paths to all data files.
103
+
104
+ Args:
105
+ dir (string): Root directory path.
106
+ class_to_idx (dictionary): A mapping of class names to id's.
107
+ extensions (list): A list of permitted data file extensions.
108
+
109
+ Returns:
110
+ images: A list of (path, target) per data file.
111
+
112
+ """
113
+ images = []
114
+ dir = os .path .expanduser (dir )
115
+ for target in sorted (os .listdir (dir )):
116
+ d = os .path .join (dir , target )
117
+ if not os .path .isdir (d ):
118
+ continue
119
+
120
+ for root , _ , fnames in sorted (os .walk (d )):
121
+ for fname in sorted (fnames ):
122
+ if has_file_allowed_extension (fname , extensions ):
123
+ path = os .path .join (root , fname )
124
+ item = (path , class_to_idx [target ])
125
+ images .append (item )
126
+
127
+ return images
128
+
129
+
107
130
def __getitem__ (self , index ):
108
131
"""
109
132
Args:
0 commit comments