diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index bef63a37b05..d14bc0c8304 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -241,3 +241,20 @@ def _make_polygon_target(file): '{city}_000000_000000_leftImg8bit.png'.format(city=city))) yield tmp_dir + + +@contextlib.contextmanager +def svhn_root(): + import scipy.io as sio + + def _make_mat(file): + images = np.zeros((32, 32, 3, 2), dtype=np.uint8) + targets = np.zeros((2,), dtype=np.uint8) + sio.savemat(file, {'X': images, 'y': targets}) + + with get_tmp_dir() as root: + _make_mat(os.path.join(root, "train_32x32.mat")) + _make_mat(os.path.join(root, "test_32x32.mat")) + _make_mat(os.path.join(root, "extra_32x32.mat")) + + yield root diff --git a/test/test_datasets.py b/test/test_datasets.py index 19914b5e2d1..0c96df11e4e 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -7,7 +7,8 @@ from torch._utils_internal import get_file_path_2 import torchvision from common_utils import get_tmp_dir -from fakedata_generation import mnist_root, cifar_root, imagenet_root, cityscapes_root +from fakedata_generation import mnist_root, cifar_root, imagenet_root, \ + cityscapes_root, svhn_root class Tester(unittest.TestCase): @@ -185,6 +186,19 @@ def test_cityscapes(self): self.assertTrue(isinstance(output[1][1], dict)) # polygon self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color + @mock.patch('torchvision.datasets.SVHN._check_integrity') + def test_svhn(self, mock_check): + mock_check.return_value = True + with svhn_root() as root: + dataset = torchvision.datasets.SVHN(root, split="train") + self.generic_classification_dataset_test(dataset, num_images=2) + + dataset = torchvision.datasets.SVHN(root, split="test") + self.generic_classification_dataset_test(dataset, num_images=2) + + dataset = torchvision.datasets.SVHN(root, split="extra") + self.generic_classification_dataset_test(dataset, num_images=2) + if __name__ == '__main__': unittest.main()