Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read COCO dataset from ZIP file #950

Closed
wants to merge 51 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
e7f6f66
Read COCO dataset images from its zipfile when it is there
koenvandesande May 23, 2019
5aebeae
Also do it for CocoCaptions
koenvandesande May 23, 2019
7803588
Move code into utils.py, remove the magic constants and import them i…
koenvandesande May 24, 2019
3e5e03d
Add test for zip lookup class
koenvandesande May 24, 2019
4e39618
Fix for Python versions < 3.6
koenvandesande May 24, 2019
9a59666
Generalize to CelebA, move part of shared logic into VisionDataset
koenvandesande May 24, 2019
0bc30f2
Fix import
koenvandesande May 24, 2019
a8b483a
flake8 fixes
koenvandesande May 24, 2019
26d51d0
Simplify implementation of ZipLookup by not keeping file descriptor open
koenvandesande May 24, 2019
29d7df8
Remove unused import
koenvandesande May 24, 2019
46ceaf0
Support reading images from ZIP for Omniglot dataset
koenvandesande May 24, 2019
3f10d56
Add common get_path_or_fp function
koenvandesande May 24, 2019
534d35d
Forgot one spot
koenvandesande May 24, 2019
26227de
Remove syntax unsupported by Python 2, replace argument with code tha…
koenvandesande May 28, 2019
547b618
Delete _C.cp37-win_amd64.pyd
koenvandesande May 28, 2019
559a5cf
Fixes and extra unit tests
May 28, 2019
741e3bb
Fixes and extra unit tests
koenvandesande May 28, 2019
f037fe9
Merge branch 'read_zipped_data' of github.com:koenvandesande/vision i…
koenvandesande May 28, 2019
c0d4dbf
Fix
koenvandesande May 28, 2019
481d45c
Fix
koenvandesande May 28, 2019
32b2311
Need to rewrite Omniglot ZIP-file because it uses compression
koenvandesande May 28, 2019
255a6f9
Fix flake8
koenvandesande May 28, 2019
8adf9af
Omniglot depends on pandas, and that is tested now in test_datasets
koenvandesande May 28, 2019
7753710
Fix
koenvandesande May 28, 2019
bfa7510
Add extra check
koenvandesande May 28, 2019
afd2d04
Refactor
koenvandesande May 28, 2019
2b7a044
Add test
koenvandesande May 28, 2019
f4f1905
Flake8
koenvandesande May 28, 2019
05e8401
tqdm too old in Travis?
koenvandesande May 28, 2019
5280d5a
For the 'smart' person who uses a symlink to their data, and then mod…
koenvandesande May 28, 2019
28b74b9
Fix mistake
koenvandesande May 29, 2019
f0f0f5d
Merge branch 'master' into read_zipped_data
koenvandesande May 29, 2019
bbfd16a
Flake8
koenvandesande May 29, 2019
39ea27a
Fix extension
koenvandesande May 29, 2019
2ac26ce
Add ZippedImageFolder class which reads a zipped version of the data …
koenvandesande May 31, 2019
629c851
Merge branch 'master' into read_zipped_data
koenvandesande Jul 12, 2019
9396c35
Fix flake8
koenvandesande Jul 12, 2019
48894bf
Update test_zippedfolder.py
koenvandesande Jul 12, 2019
0fa8035
Fix test
koenvandesande Jul 13, 2019
c05281a
Fix omniglot
koenvandesande Jul 13, 2019
ef3ba78
0.4.0 packaging (#1212)
ezyang Aug 8, 2019
66bc6f9
Don't build nightlies on 0.4.0 branch.
ezyang Aug 8, 2019
a1ed206
Refactor version suffix so conda packages don't get suffixes. (#1218)…
ezyang Aug 8, 2019
7a8b133
Merge remote-tracking branch 'upstream/v0.4.0' into read_zipped_data
koenvandesande Aug 30, 2019
b8c2c5d
Merge remote-tracking branch 'upstream/master' into read_zipped_data
koenvandesande Aug 30, 2019
8df35fa
Merge branch 'master' into read_zipped_data
koenvandesande Oct 4, 2019
d68ce83
Merge branch 'master' into read_zipped_data
koenvandesande Oct 4, 2019
17de30d
Merge branch 'master' into read_zipped_data
koenvandesande Mar 10, 2020
393cfd6
Update config.yml
koenvandesande Mar 10, 2020
f28b324
Remove EOL
koenvandesande Mar 10, 2020
6247d96
Merge branch 'master' into read_zipped_data
koenvandesande Oct 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import torchvision.datasets.utils as utils
import unittest
import zipfile

TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'grace_hopper_517x606.jpg')
Expand Down Expand Up @@ -41,6 +42,31 @@ def test_download_url_retry_http(self):
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)

def test_ziplookup(self):
    """Check that ZipLookup serves stored members and rejects unknown names.

    Builds a small uncompressed archive from TEST_FILE in a temporary
    directory, reads two members back through ``utils.ZipLookup`` and
    verifies both the reported name and the returned bytes, then confirms
    that an unknown member name raises ``KeyError``.
    """
    temp_dir = tempfile.mkdtemp()
    temp_filename = os.path.join(temp_dir, "ziplookup.zip")
    try:
        # The context manager closes the archive even when a write fails,
        # so the cleanup in ``finally`` never runs against an open file.
        with zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_STORED, allowZip64=True) as z:
            z.write(TEST_FILE, "hopper.jpg")
            z.write(TEST_FILE)
            z.write(TEST_FILE, "hopper79.jpg")
            z.write(TEST_FILE, "somepath/hopper.jpg")

        with open(TEST_FILE, "rb") as original:
            expected = original.read()

        lookup = utils.ZipLookup(temp_filename)
        for member in ("hopper.jpg", "somepath/hopper.jpg"):
            f = lookup[member]
            assert f.name.endswith(".jpg")
            # The member was stored uncompressed, so it must match the source.
            assert f.read() == expected

        try:
            lookup["does_not_exist.jpg"]
        except KeyError:
            pass
        else:
            # Only reached when the lookup did not raise.
            assert False, "Should not return something for non-existant file"
        del lookup
    finally:
        shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
9 changes: 7 additions & 2 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def __init__(self, root,
transform=None, target_transform=None,
download=False):
import pandas
super(CelebA, self).__init__(root)
super(CelebA, self).__init__(root,
root_zipfilename=os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
Expand Down Expand Up @@ -131,7 +132,11 @@ def download(self):
f.extractall(os.path.join(self.root, self.base_folder))

def __getitem__(self, index):
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
if self.root_zip is not None:
f = self.root_zip[os.path.join("img_align_celeba", self.filename[index])]
X = PIL.Image.open(f)
else:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

target = []
for t in self.target_type:
Expand Down
18 changes: 14 additions & 4 deletions torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class CocoCaptions(VisionDataset):
"""

def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform,
root_zipfilename=self.root + ".zip")
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
Expand All @@ -65,7 +66,11 @@ def __getitem__(self, index):

path = coco.loadImgs(img_id)[0]['file_name']

img = Image.open(os.path.join(self.root, path)).convert('RGB')
if self.root_zip is not None:
f = self.root_zip[os.path.split(self.root)[1] + "/" + path]
img = Image.open(f).convert('RGB')
else:
img = Image.open(os.path.join(self.root, path)).convert('RGB')
koenvandesande marked this conversation as resolved.
Show resolved Hide resolved

if self.transforms is not None:
img, target = self.transforms(img, target)
Expand All @@ -89,7 +94,8 @@ class CocoDetection(VisionDataset):
"""

def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
super(CocoDetection, self).__init__(root, transforms, transform, target_transform,
root_zipfilename=self.root + ".zip")
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
Expand All @@ -109,7 +115,11 @@ def __getitem__(self, index):

path = coco.loadImgs(img_id)[0]['file_name']

img = Image.open(os.path.join(self.root, path)).convert('RGB')
if self.root_zip is not None:
f = self.root_zip[os.path.split(self.root)[1] + "/" + path]
img = Image.open(f).convert('RGB')
else:
img = Image.open(os.path.join(self.root, path)).convert('RGB')
if self.transforms is not None:
img, target = self.transforms(img, target)

Expand Down
37 changes: 37 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import hashlib
import errno
from torch.utils.model_zoo import tqdm
import zipfile
import io
import struct


def gen_bar_updater():
Expand Down Expand Up @@ -189,3 +192,37 @@ def _save_response_content(response, destination, chunk_size=32768):
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()


# thread-safe/multiprocessing-safe (unlike a Python ZipFile instance)
class ZipLookup(object):
    """Random-access reader for the members of an uncompressed ZIP archive.

    The archive's member list is scanned once at construction time to record,
    for every non-directory member, the offset of its local file header and
    its size in bytes. Each lookup then opens the archive file afresh, reads
    that byte range and closes the file again, so neither a file descriptor
    nor a ``zipfile.ZipFile`` object is kept between calls.

    Args:
        filename (str): path of the ZIP archive. Every non-directory member
            must be stored without compression (``zipfile.ZIP_STORED``).

    Raises:
        ValueError: if a member is compressed, or if its stored size differs
            from its uncompressed size.
    """

    def __init__(self, filename):
        self.root_zip_filename = filename
        # member name -> (offset of its local file header, size in bytes)
        self.root_zip_lookup = {}

        with zipfile.ZipFile(filename, "r") as root_zip:
            for info in root_zip.infolist():
                # ``endswith`` (rather than indexing [-1]) also copes with an
                # empty member name instead of raising IndexError.
                if info.filename.endswith('/'):
                    # skip directories
                    continue
                if info.compress_type != zipfile.ZIP_STORED:
                    raise ValueError("Only uncompressed ZIP file supported: " + info.filename)
                if info.compress_size != info.file_size:
                    raise ValueError("Must be the same when uncompressed")
                self.root_zip_lookup[info.filename] = (info.header_offset, info.compress_size)

    def __getitem__(self, path):
        """Return the content of member ``path`` as an in-memory binary file.

        Args:
            path (str): member name exactly as it is stored in the archive.

        Returns:
            io.BytesIO: buffer holding the member's bytes, with its ``name``
            attribute set to ``path``.

        Raises:
            KeyError: if the archive has no member called ``path``.
        """
        # Resolve the key first: a missing member must raise KeyError without
        # ever opening (and then leaking) a handle on the archive.
        header_offset, size = self.root_zip_lookup[path]

        with open(self.root_zip_filename, "rb") as z:
            z.seek(header_offset)
            fheader = struct.unpack(zipfile.structFileHeader, z.read(zipfile.sizeFileHeader))
            # The payload starts after the fixed-size local header plus the
            # variable-length file name and extra field that follow it.
            offset = (header_offset + zipfile.sizeFileHeader +
                      fheader[zipfile._FH_FILENAME_LENGTH] +
                      fheader[zipfile._FH_EXTRA_FIELD_LENGTH])

            z.seek(offset)
            f = io.BytesIO(z.read(size))
        f.name = path
        return f
9 changes: 8 additions & 1 deletion torchvision/datasets/vision.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import torch
import torch.utils.data as data
from .utils import ZipLookup


class VisionDataset(data.Dataset):
_repr_indent = 4

def __init__(self, root, transforms=None, transform=None, target_transform=None):
def __init__(self, root, transforms=None, transform=None, target_transform=None, root_zipfilename=None):
if isinstance(root, torch._six.string_classes):
root = os.path.expanduser(root)
self.root = root
Expand All @@ -25,6 +26,12 @@ def __init__(self, root, transforms=None, transform=None, target_transform=None)
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms

self.root_zip = None
self.root_zipfilename = root_zipfilename
if self.root_zipfilename is not None and os.path.exists(self.root_zipfilename):
self.root_zip = ZipLookup(self.root_zipfilename)
print("Using ZIP file for data source:", self.root_zipfilename)

def __getitem__(self, index):
raise NotImplementedError

Expand Down