
Read COCO dataset from ZIP file #950


Closed
wants to merge 51 commits into from

Changes from 35 commits

Commits (51 commits)
e7f6f66
Read COCO dataset images from its zipfile when it is there
koenvandesande May 23, 2019
5aebeae
Also do it for CocoCaptions
koenvandesande May 23, 2019
7803588
Move code into utils.py, remove the magic constants and import them i…
koenvandesande May 24, 2019
3e5e03d
Add test for zip lookup class
koenvandesande May 24, 2019
4e39618
Fix for Python versions < 3.6
koenvandesande May 24, 2019
9a59666
Generalize to CelebA, move part of shared logic into VisionDataset
koenvandesande May 24, 2019
0bc30f2
Fix import
koenvandesande May 24, 2019
a8b483a
flake8 fixes
koenvandesande May 24, 2019
26d51d0
Simplify implementation of ZipLookup by not keeping file descriptor open
koenvandesande May 24, 2019
29d7df8
Remove unused import
koenvandesande May 24, 2019
46ceaf0
Support reading images from ZIP for Omniglot dataset
koenvandesande May 24, 2019
3f10d56
Add common get_path_or_fp function
koenvandesande May 24, 2019
534d35d
Forgot one spot
koenvandesande May 24, 2019
26227de
Remove syntax unsupported by Python 2, replace argument with code tha…
koenvandesande May 28, 2019
547b618
Delete _C.cp37-win_amd64.pyd
koenvandesande May 28, 2019
559a5cf
Fixes and extra unit tests
May 28, 2019
741e3bb
Fixes and extra unit tests
koenvandesande May 28, 2019
f037fe9
Merge branch 'read_zipped_data' of github.com:koenvandesande/vision i…
koenvandesande May 28, 2019
c0d4dbf
Fix
koenvandesande May 28, 2019
481d45c
Fix
koenvandesande May 28, 2019
32b2311
Need to rewrite Omniglot ZIP-file because it uses compression
koenvandesande May 28, 2019
255a6f9
Fix flake8
koenvandesande May 28, 2019
8adf9af
Omniglot depends on pandas, and that is tested now in test_datasets
koenvandesande May 28, 2019
7753710
Fix
koenvandesande May 28, 2019
bfa7510
Add extra check
koenvandesande May 28, 2019
afd2d04
Refactor
koenvandesande May 28, 2019
2b7a044
Add test
koenvandesande May 28, 2019
f4f1905
Flake8
koenvandesande May 28, 2019
05e8401
tqdm too old in Travis?
koenvandesande May 28, 2019
5280d5a
For the 'smart' person who uses a symlink to their data, and then mod…
koenvandesande May 28, 2019
28b74b9
Fix mistake
koenvandesande May 29, 2019
f0f0f5d
Merge branch 'master' into read_zipped_data
koenvandesande May 29, 2019
bbfd16a
Flake8
koenvandesande May 29, 2019
39ea27a
Fix extension
koenvandesande May 29, 2019
2ac26ce
Add ZippedImageFolder class which reads a zipped version of the data …
koenvandesande May 31, 2019
629c851
Merge branch 'master' into read_zipped_data
koenvandesande Jul 12, 2019
9396c35
Fix flake8
koenvandesande Jul 12, 2019
48894bf
Update test_zippedfolder.py
koenvandesande Jul 12, 2019
0fa8035
Fix test
koenvandesande Jul 13, 2019
c05281a
Fix omniglot
koenvandesande Jul 13, 2019
ef3ba78
0.4.0 packaging (#1212)
ezyang Aug 8, 2019
66bc6f9
Don't build nightlies on 0.4.0 branch.
ezyang Aug 8, 2019
a1ed206
Refactor version suffix so conda packages don't get suffixes. (#1218)…
ezyang Aug 8, 2019
7a8b133
Merge remote-tracking branch 'upstream/v0.4.0' into read_zipped_data
koenvandesande Aug 30, 2019
b8c2c5d
Merge remote-tracking branch 'upstream/master' into read_zipped_data
koenvandesande Aug 30, 2019
8df35fa
Merge branch 'master' into read_zipped_data
koenvandesande Oct 4, 2019
d68ce83
Merge branch 'master' into read_zipped_data
koenvandesande Oct 4, 2019
17de30d
Merge branch 'master' into read_zipped_data
koenvandesande Mar 10, 2020
393cfd6
Update config.yml
koenvandesande Mar 10, 2020
f28b324
Remove EOL
koenvandesande Mar 10, 2020
6247d96
Merge branch 'master' into read_zipped_data
koenvandesande Oct 17, 2020
1 change: 1 addition & 0 deletions .travis.yml
@@ -33,6 +33,7 @@ before_install:
pip uninstall -y pillow && CC="cc -march=native" pip install --force-reinstall pillow-simd
fi
- pip install future
- pip install pandas tqdm
- pip install pytest pytest-cov codecov


27 changes: 27 additions & 0 deletions test/test_datasets.py
@@ -1,4 +1,5 @@
import PIL
import os
import shutil
import tempfile
import unittest
@@ -33,6 +34,32 @@ def test_fashionmnist(self):
self.assertTrue(isinstance(target, int))
shutil.rmtree(tmp_dir)

def test_celeba(self):
temp_dir = tempfile.mkdtemp()
ds = torchvision.datasets.CelebA(root=temp_dir, download=True)
assert len(ds) == 162770
assert ds[40711] is not None

# 2nd time, the ZIP file will be detected (because now it has been downloaded)
ds2 = torchvision.datasets.CelebA(root=temp_dir, download=True)
assert ds2.root_zip is not None, "Transparent ZIP reading support broken: ZIP file not found"
assert len(ds2) == 162770
assert ds2[40711] is not None
shutil.rmtree(temp_dir)

def test_omniglot(self):
temp_dir = tempfile.mkdtemp()
ds = torchvision.datasets.Omniglot(root=temp_dir, download=True)
assert len(ds) == 19280
assert ds[4071] is not None

# 2nd time, the ZIP file will be detected (because now it has been downloaded)
ds2 = torchvision.datasets.Omniglot(root=temp_dir, download=True)
assert ds2.root_zip is not None, "Transparent ZIP reading support broken: ZIP file not found"
assert len(ds2) == 19280
assert ds2[4071] is not None
shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
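
A minimal usage sketch of the transparent ZIP detection these tests exercise (paths are placeholders and mirror the test above; assumes the CelebA archives end up under ./data):

import torchvision

# First construction downloads and extracts the archives as before.
ds = torchvision.datasets.CelebA(root="./data", download=True)

# On a later construction the already-downloaded img_align_celeba.zip is
# detected and images are served straight from the archive.
ds = torchvision.datasets.CelebA(root="./data", download=True)
assert ds.root_zip is not None
img, target = ds[0]
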
47 changes: 47 additions & 0 deletions test/test_datasets_utils.py
@@ -44,6 +44,53 @@ def test_download_url_retry_http(self):
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)

def test_convert_zip_to_uncompressed_zip(self):
temp_dir = tempfile.mkdtemp()
temp_filename = os.path.join(temp_dir, "convert.zip")
temp_filename2 = os.path.join(temp_dir, "converted.zip")
try:
z = zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_DEFLATED, allowZip64=True)
z.write(TEST_FILE, "hopper.jpg")
z.write(TEST_FILE)
z.write(TEST_FILE, "hopper79.jpg")
z.write(TEST_FILE, "somepath/hopper.jpg")
z.close()

utils.convert_zip_to_uncompressed_zip(temp_filename, temp_filename2)
with zipfile.ZipFile(temp_filename2) as u:
for info in u.infolist():
assert info.compress_type == zipfile.ZIP_STORED
_ = utils.ZipLookup(temp_filename2)
finally:
shutil.rmtree(temp_dir)

def test_ziplookup(self):
temp_dir = tempfile.mkdtemp()
temp_filename = os.path.join(temp_dir, "ziplookup.zip")
try:
z = zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_STORED, allowZip64=True)
z.write(TEST_FILE, "hopper.jpg")
z.write(TEST_FILE)
z.write(TEST_FILE, "hopper79.jpg")
z.write(TEST_FILE, "somepath/hopper.jpg")
z.close()

lookup = utils.ZipLookup(temp_filename)
f = lookup["hopper.jpg"]
assert f.name.endswith(".jpg")
f = lookup["somepath/hopper.jpg"]
assert f.name.endswith(".jpg")
try:
f = lookup["does_not_exist.jpg"]
assert False, "Should not return something for non-existent file"
except KeyError:
pass
assert "hopper.jpg" in lookup.keys()
assert "somepath/hopper.jpg" in lookup.keys()
del lookup
finally:
shutil.rmtree(temp_dir)

def test_extract_zip(self):
temp_dir = tempfile.mkdtemp()
with tempfile.NamedTemporaryFile(suffix='.zip') as f:
45 changes: 45 additions & 0 deletions test/test_zippedfolder.py
@@ -0,0 +1,45 @@
import unittest

import tempfile
import os
import shutil
import zipfile

from torchvision.datasets import ZippedImageFolder
from torch._utils_internal import get_file_path_2


class Tester(unittest.TestCase):
root = os.path.normpath(get_file_path_2('test/assets/dataset/'))
classes = ['a', 'b']
class_a_images = [os.path.normpath(get_file_path_2(os.path.join('test/assets/dataset/a/', path)))
for path in ['a1.png', 'a2.png', 'a3.png']]
class_b_images = [os.path.normpath(get_file_path_2(os.path.join('test/assets/dataset/b/', path)))
for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']]

def test_zipped_image_folder(self):
temp_dir = tempfile.mkdtemp()
temp_filename = os.path.join(temp_dir, "dataset.zip")
try:
zf = zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_STORED, allowZip64=True)
for dirname, subdirs, files in os.walk(Tester.root):
for filename in files:
zf.write(os.path.join(dirname, filename),
os.path.relpath(os.path.join(dirname, filename), Tester.root))
zf.close()

dataset = ZippedImageFolder(root=temp_filename)
for cls in Tester.classes:
self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])
class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_path.replace('test/assets/dataset/', ''), class_a_idx) for img_path in Tester.class_a_images]
imgs_b = [(img_path.replace('test/assets/dataset/', ''), class_b_idx) for img_path in Tester.class_b_images]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)
finally:
shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
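
A hedged sketch of how the new ZippedImageFolder is meant to be used outside this test, assuming dataset.zip is an uncompressed (ZIP_STORED) archive laid out like a regular ImageFolder root (a/a1.png, b/b1.png, ...):

from torchvision.datasets import ZippedImageFolder

# The archive itself acts as the dataset root; nothing is extracted to disk.
ds = ZippedImageFolder(root="dataset.zip")
print(ds.classes)        # e.g. ['a', 'b']
print(ds.class_to_idx)   # e.g. {'a': 0, 'b': 1}
img, label = ds[0]       # image decoded directly from the archive
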
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
@@ -1,5 +1,6 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder, DatasetFolder
from .zippedfolder import ZippedImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
@@ -21,7 +22,7 @@
from .usps import USPS

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'ImageFolder', 'DatasetFolder', 'ZippedImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
25 changes: 13 additions & 12 deletions torchvision/datasets/celeba.py
@@ -53,7 +53,8 @@ def __init__(self, root,
transform=None, target_transform=None,
download=False):
import pandas
super(CelebA, self).__init__(root)
root = os.path.join(root, self.base_folder)
super(CelebA, self).__init__(root, root_zipfilename=os.path.join(root, "img_align_celeba.zip"))
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
@@ -82,19 +83,19 @@ def __init__(self, root,
raise ValueError('Wrong split entered! Please use split="train" '
'or split="valid" or split="test"')

with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
with open(os.path.join(self.root, "list_eval_partition.txt"), "r") as f:
splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)

with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
with open(os.path.join(self.root, "identity_CelebA.txt"), "r") as f:
self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)

with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
with open(os.path.join(self.root, "list_bbox_celeba.txt"), "r") as f:
self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)

with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
with open(os.path.join(self.root, "list_landmarks_align_celeba.txt"), "r") as f:
self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)

with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
with open(os.path.join(self.root, "list_attr_celeba.txt"), "r") as f:
self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)

mask = (splits[1] == split)
@@ -107,15 +108,15 @@ def __init__(self, root,

def _check_integrity(self):
for (_, md5, filename) in self.file_list:
fpath = os.path.join(self.root, self.base_folder, filename)
fpath = os.path.join(self.root, filename)
_, ext = os.path.splitext(filename)
# Allow original archive to be deleted (zip and 7z)
# Only need the extracted images
if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
return False

# Should check a hash of the images
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
return os.path.isdir(os.path.join(self.root, "img_align_celeba"))

def download(self):
import zipfile
@@ -125,13 +126,13 @@ def download(self):
return

for (file_id, md5, filename) in self.file_list:
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
download_file_from_google_drive(file_id, self.root, filename, md5)

with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
f.extractall(os.path.join(self.root, self.base_folder))
with zipfile.ZipFile(os.path.join(self.root, "img_align_celeba.zip"), "r") as f:
f.extractall(self.root)

def __getitem__(self, index):
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
X = PIL.Image.open(self.get_path_or_fp("img_align_celeba", self.filename[index]))

target = []
for t in self.target_type:
10 changes: 6 additions & 4 deletions torchvision/datasets/coco.py
@@ -44,7 +44,8 @@ class CocoCaptions(VisionDataset):
"""

def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform,
root_zipfilename=root + ".zip")
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
@@ -65,7 +66,7 @@ def __getitem__(self, index):

path = coco.loadImgs(img_id)[0]['file_name']

img = Image.open(os.path.join(self.root, path)).convert('RGB')
img = Image.open(self.get_path_or_fp(path)).convert('RGB')

if self.transforms is not None:
img, target = self.transforms(img, target)
@@ -89,7 +90,8 @@ class CocoDetection(VisionDataset):
"""

def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
super(CocoDetection, self).__init__(root, transforms, transform, target_transform,
root_zipfilename=root + ".zip")
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
@@ -109,7 +111,7 @@ def __getitem__(self, index):

path = coco.loadImgs(img_id)[0]['file_name']

img = Image.open(os.path.join(self.root, path)).convert('RGB')
img = Image.open(self.get_path_or_fp(path)).convert('RGB')
if self.transforms is not None:
img, target = self.transforms(img, target)

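
A usage sketch of what this change enables for COCO (paths are placeholders; assumes coco/train2017.zip stores its images uncompressed and sits where the extracted train2017/ directory would otherwise be):

from torchvision.datasets import CocoDetection

# root + ".zip" is probed automatically: if coco/train2017.zip exists, images
# are opened through get_path_or_fp() instead of the extracted directory.
ds = CocoDetection(root="coco/train2017",
                   annFile="coco/annotations/instances_train2017.json")
img, target = ds[0]
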
17 changes: 10 additions & 7 deletions torchvision/datasets/omniglot.py
@@ -3,7 +3,7 @@
from os.path import join
import os
from .vision import VisionDataset
from .utils import download_and_extract, check_integrity, list_dir, list_files
from .utils import download_and_extract, check_integrity, list_dir, list_files, convert_zip_to_uncompressed_zip


class Omniglot(VisionDataset):
@@ -31,10 +31,11 @@ class Omniglot(VisionDataset):
def __init__(self, root, background=True,
transform=None, target_transform=None,
download=False):
super(Omniglot, self).__init__(join(root, self.folder))
self.background = background
super(Omniglot, self).__init__(join(root, self.folder),
root_zipfilename=join(root, self.folder, self._get_target_folder() + ".zip"))
self.transform = transform
self.target_transform = target_transform
self.background = background

if download:
self.download()
@@ -63,8 +64,8 @@ def __getitem__(self, index):
tuple: (image, target) where target is index of the target character class.
"""
image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode='r').convert('L')
image_path_or_fp = self.get_path_or_fp(self._get_target_folder(), self._characters[character_class], image_name)
image = Image.open(image_path_or_fp, mode='r').convert('L')

if self.transform:
image = self.transform(image)
Expand All @@ -76,7 +77,7 @@ def __getitem__(self, index):

def _check_integrity(self):
zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
if not check_integrity(join(self.root, zip_filename + '.org.zip'), self.zips_md5[zip_filename]):
return False
return True

@@ -87,8 +88,10 @@ def download(self):

filename = self._get_target_folder()
zip_filename = filename + '.zip'
org_filename = filename + '.org.zip'
url = self.download_url_prefix + '/' + zip_filename
download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])
download_and_extract(url, self.root, org_filename, self.zips_md5[filename])
convert_zip_to_uncompressed_zip(join(self.root, org_filename), join(self.root, zip_filename))

def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation'
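
For reference, the conversion step Omniglot now performs after downloading can also be run by hand with the helper added to utils.py below (filenames are placeholders):

from torchvision.datasets.utils import convert_zip_to_uncompressed_zip

# Rewrite the DEFLATE-compressed download as a ZIP_STORED archive so that
# ZipLookup can seek straight to each member without decompressing.
convert_zip_to_uncompressed_zip("images_background.org.zip", "images_background.zip")
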
46 changes: 46 additions & 0 deletions torchvision/datasets/utils.py
@@ -7,6 +7,8 @@
import zipfile

from torch.utils.model_zoo import tqdm
import io
import struct


def gen_bar_updater():
@@ -236,3 +238,47 @@ def download_and_extract(url, root, filename, md5=None, remove_finished=False):
download_url(url, root, filename, md5)
print("Extracting {} to {}".format(os.path.join(root, filename), root))
extract_file(os.path.join(root, filename), root, remove_finished)


def convert_zip_to_uncompressed_zip(org_filename, zip_filename):
with zipfile.ZipFile(org_filename, 'r') as zip_file:
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_STORED) as out_file:
for item in zip_file.infolist():
out_file.writestr(item.filename, zip_file.read(item))


# thread-safe/multiprocessing-safe (unlike a Python ZipFile instance)
class ZipLookup(object):
def __init__(self, filename):
self.root_zip_filename = filename
self.root_zip_lookup = {}

with zipfile.ZipFile(filename, "r") as root_zip:
for info in root_zip.infolist():
if info.filename[-1] == '/':
# skip directories
continue
if info.compress_type != zipfile.ZIP_STORED:
raise ValueError("Only uncompressed ZIP file supported: " + info.filename)
if info.compress_size != info.file_size:
raise ValueError("Must be the same when uncompressed")
self.root_zip_lookup[info.filename] = (info.header_offset, info.compress_size)

def __getitem__(self, path):
z = open(self.root_zip_filename, "rb")
header_offset, size = self.root_zip_lookup[path]

z.seek(header_offset)
fheader = z.read(zipfile.sizeFileHeader)
fheader = struct.unpack(zipfile.structFileHeader, fheader)
offset = header_offset + zipfile.sizeFileHeader + fheader[zipfile._FH_FILENAME_LENGTH] + \
fheader[zipfile._FH_EXTRA_FIELD_LENGTH]

z.seek(offset)
f = io.BytesIO(z.read(size))
f.name = path
z.close()
return f

def keys(self):
return self.root_zip_lookup.keys()
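
A minimal standalone sketch of the ZipLookup class defined above (assumes data.zip is an uncompressed archive containing img/001.jpg):

from torchvision.datasets.utils import ZipLookup

lookup = ZipLookup("data.zip")   # index the archive once; no file handle is kept open
print(sorted(lookup.keys()))     # member paths; directory entries are skipped
fp = lookup["img/001.jpg"]       # io.BytesIO with that member's bytes, fp.name == "img/001.jpg"
data = fp.read()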