@@ -57,67 +57,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
5757
5858
5959class Tester (DatasetTestcase ):
def test_imagefolder(self):
    """Check class discovery and sample listing of ``ImageFolder`` on the bundled fake data.

    The loader is the identity function, so every sample the dataset returns is a
    ``(path, class_index)`` tuple that can be compared against the expected file list
    directly. The checks run twice: once without a filter, and once with an
    ``is_valid_file`` predicate that only accepts paths containing ``'3'``.
    """
    # TODO: create the fake data on-the-fly
    fakedata_dir = get_file_path_2(os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

    with get_tmp_dir(src=os.path.join(fakedata_dir, 'imagefolder')) as root:
        expected_classes = sorted(['a', 'b'])
        a_files = [os.path.join(root, 'a', name) for name in ('a1.png', 'a2.png', 'a3.png')]
        b_files = [os.path.join(root, 'b', name) for name in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]

        def expected_samples(dataset, keep=None):
            # Resolve both class indices up front, then pair every (optionally filtered) file with its index.
            index_of = {label: dataset.class_to_idx[label] for label in ('a', 'b')}
            pairs = [
                (path, index_of[label])
                for label, paths in (('a', a_files), ('b', b_files))
                for path in paths
                if keep is None or keep(path)
            ]
            return sorted(pairs)

        def all_outputs(dataset):
            # Index the dataset explicitly so that __getitem__ and __len__ are both exercised.
            return sorted(dataset[i] for i in range(len(dataset)))

        dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

        # test if all classes are present
        self.assertEqual(expected_classes, sorted(dataset.classes))

        # test if combination of classes and class_to_index functions correctly
        for label in expected_classes:
            self.assertEqual(label, dataset.classes[dataset.class_to_idx[label]])

        # test if all images were detected correctly
        samples = expected_samples(dataset)
        self.assertEqual(samples, dataset.imgs)

        # test if the datasets outputs all images correctly
        self.assertEqual(samples, all_outputs(dataset))

        # redo all tests with specified valid image files
        def contains_three(path):
            return '3' in path

        dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x, is_valid_file=contains_three)
        self.assertEqual(expected_classes, sorted(dataset.classes))

        samples = expected_samples(dataset, keep=contains_three)
        self.assertEqual(samples, dataset.imgs)
        self.assertEqual(samples, all_outputs(dataset))
110-
def test_imagefolder_empty(self):
    """Check that ``ImageFolder`` raises ``FileNotFoundError`` when it finds no samples.

    Two constructions are tried against the same temporary directory (created without a
    ``src``, so presumably empty -- confirm against ``get_tmp_dir``): the default one, and
    one whose ``is_valid_file`` predicate rejects every path.
    """
    with get_tmp_dir() as root:
        for extra_kwargs in ({}, {"is_valid_file": lambda x: False}):
            with self.assertRaises(FileNotFoundError):
                torchvision.datasets.ImageFolder(root, loader=lambda x: x, **extra_kwargs)
120-
12160 @mock .patch ('torchvision.datasets.SVHN._check_integrity' )
12261 @unittest .skipIf (not HAS_SCIPY , "scipy unavailable" )
12362 def test_svhn (self , mock_check ):
@@ -1673,5 +1612,95 @@ def test_num_examples_test50k(self):
16731612 self .assertEqual (len (dataset ), info ["num_examples" ] - 10000 )
16741613
16751614
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.DatasetFolder`` using generated image and video folders.

    Each known file extension is paired with its own single-letter class folder; only the
    extensions selected by the active config are materialised on disk.
    """

    DATASET_CLASS = datasets.DatasetFolder

    # DatasetFolder itself has no fixed sample type: whatever the 'loader' returns is the sample. The loader used
    # in these tests (see 'dataset_args()') hands the path back unchanged, hence 'str' for the sample.
    FEATURE_TYPES = (str, int)

    _IMAGE_EXTENSIONS = ("jpg", "png")
    _VIDEO_EXTENSIONS = ("avi", "mp4")
    _EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)

    # 'extensions' and 'is_valid_file' are mutually exclusive parameters of DatasetFolder, and one of them is
    # required. The configs below only vary 'extensions'; 'is_valid_file' is covered by 'test_is_valid_file()'.
    DEFAULT_CONFIG = {"extensions": _EXTENSIONS}
    ADDITIONAL_CONFIGS = (
        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
        {"extensions": _IMAGE_EXTENSIONS},
        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
        {"extensions": _VIDEO_EXTENSIONS},
    )

    def dataset_args(self, tmpdir, config):
        """Return the positional dataset arguments: the root directory and an identity loader."""

        def identity(path):
            return path

        return tmpdir, identity

    def inject_fake_data(self, tmpdir, config):
        """Create one class folder per selected extension and describe what was written.

        The selected extensions come from ``config["extensions"]``; when that is falsy they
        are derived by probing ``config["is_valid_file"]``. Every known extension is paired
        with a letter from ``string.ascii_letters`` that serves as its class name.

        Returns:
            dict: ``num_examples`` (total number of generated files) and ``classes``
            (class names in creation order).
        """
        selected = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])

        total = 0
        class_names = []
        for extension, class_name in zip(self._EXTENSIONS, string.ascii_letters):
            if extension in selected:
                if extension in self._IMAGE_EXTENSIONS:
                    make_folder = datasets_utils.create_image_folder
                else:
                    make_folder = datasets_utils.create_video_folder

                # torch.randint excludes the upper bound, so each class gets 1 or 2 files.
                count = torch.randint(1, 3, size=()).item()
                make_folder(tmpdir, class_name, lambda idx: self._file_name_fn(class_name, extension, idx), count)

                total += count
                class_names.append(class_name)

        return {"num_examples": total, "classes": class_names}

    def _file_name_fn(self, cls, ext, idx):
        """Build the file name ``<cls>_<idx>.<ext>`` for a generated example."""
        stem = f"{cls}_{idx}"
        return f"{stem}.{ext}"

    def _is_valid_file_to_extensions(self, is_valid_file):
        """Return the set of known extensions that ``is_valid_file`` accepts for a probe name ``foo.<ext>``."""
        accepted = set()
        for candidate in self._EXTENSIONS:
            if is_valid_file(f"foo.{candidate}"):
                accepted.add(candidate)
        return accepted

    @datasets_utils.test_all_configs
    def test_is_valid_file(self, config):
        """Check the dataset length when files are selected via ``is_valid_file`` instead of ``extensions``."""
        allowed = config.pop("extensions")

        def has_allowed_suffix(file):
            # Path.suffix includes the leading dot; strip it before comparing.
            return pathlib.Path(file).suffix[1:] in allowed

        # 'extensions=None' has to be passed explicitly, otherwise the value from DEFAULT_CONFIG would be
        # filled in again.
        with self.create_dataset(config, extensions=None, is_valid_file=has_allowed_suffix) as (dataset, info):
            expected_length = info["num_examples"]
            self.assertEqual(len(dataset), expected_length)

    @datasets_utils.test_all_configs
    def test_classes(self, config):
        """Check that the dataset reports exactly the generated class names, in order."""
        with self.create_dataset(config) as (dataset, info):
            expected_classes = info["classes"]
            self.assertSequenceEqual(dataset.classes, expected_classes)
1683+
1684+
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.ImageFolder`` using two generated class folders of PNG files."""

    DATASET_CLASS = datasets.ImageFolder

    def inject_fake_data(self, tmpdir, config):
        """Create the class folders ``a`` and ``b`` and describe what was written.

        Returns:
            dict: ``num_examples`` (total number of generated files) and ``classes``
            (the tuple of class names).
        """
        class_names = ("a", "b")
        counts = []
        for class_name in class_names:
            # torch.randint excludes the upper bound, so each class gets 1 or 2 files. The count is drawn
            # before the folder is created to keep the order of random draws stable.
            count = torch.randint(1, 3, size=()).item()
            counts.append(count)

            datasets_utils.create_image_folder(tmpdir, class_name, lambda idx: f"{class_name}_{idx}.png", count)

        return {"num_examples": sum(counts), "classes": class_names}

    @datasets_utils.test_all_configs
    def test_classes(self, config):
        """Check that the dataset reports exactly the generated class names, in order."""
        with self.create_dataset(config) as (dataset, info):
            expected_classes = info["classes"]
            self.assertSequenceEqual(dataset.classes, expected_classes)
1703+
1704+
# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments