
Read COCO dataset from ZIP file #950


Closed. Wants to merge 51 commits.

Commits (51)
e7f6f66  Read COCO dataset images from its zipfile when it is there (koenvandesande, May 23, 2019)
5aebeae  Also do it for CocoCaptions (koenvandesande, May 23, 2019)
7803588  Move code into utils.py, remove the magic constants and import them i… (koenvandesande, May 24, 2019)
3e5e03d  Add test for zip lookup class (koenvandesande, May 24, 2019)
4e39618  Fix for Python versions < 3.6 (koenvandesande, May 24, 2019)
9a59666  Generalize to CelebA, move part of shared logic into VisionDataset (koenvandesande, May 24, 2019)
0bc30f2  Fix import (koenvandesande, May 24, 2019)
a8b483a  flake8 fixes (koenvandesande, May 24, 2019)
26d51d0  Simplify implementation of ZipLookup by not keeping file descriptor open (koenvandesande, May 24, 2019)
29d7df8  Remove unused import (koenvandesande, May 24, 2019)
46ceaf0  Support reading images from ZIP for Omniglot dataset (koenvandesande, May 24, 2019)
3f10d56  Add common get_path_or_fp function (koenvandesande, May 24, 2019)
534d35d  Forgot one spot (koenvandesande, May 24, 2019)
26227de  Remove syntax unsupported by Python 2, replace argument with code tha… (koenvandesande, May 28, 2019)
547b618  Delete _C.cp37-win_amd64.pyd (koenvandesande, May 28, 2019)
559a5cf  Fixes and extra unit tests (May 28, 2019)
741e3bb  Fixes and extra unit tests (koenvandesande, May 28, 2019)
f037fe9  Merge branch 'read_zipped_data' of github.com:koenvandesande/vision i… (koenvandesande, May 28, 2019)
c0d4dbf  Fix (koenvandesande, May 28, 2019)
481d45c  Fix (koenvandesande, May 28, 2019)
32b2311  Need to rewrite Omniglot ZIP-file because it uses compression (koenvandesande, May 28, 2019)
255a6f9  Fix flake8 (koenvandesande, May 28, 2019)
8adf9af  Omniglot depends on pandas, and that is tested now in test_datasets (koenvandesande, May 28, 2019)
7753710  Fix (koenvandesande, May 28, 2019)
bfa7510  Add extra check (koenvandesande, May 28, 2019)
afd2d04  Refactor (koenvandesande, May 28, 2019)
2b7a044  Add test (koenvandesande, May 28, 2019)
f4f1905  Flake8 (koenvandesande, May 28, 2019)
05e8401  tqdm too old in Travis? (koenvandesande, May 28, 2019)
5280d5a  For the 'smart' person who uses a symlink to their data, and then mod… (koenvandesande, May 28, 2019)
28b74b9  Fix mistake (koenvandesande, May 29, 2019)
f0f0f5d  Merge branch 'master' into read_zipped_data (koenvandesande, May 29, 2019)
bbfd16a  Flake8 (koenvandesande, May 29, 2019)
39ea27a  Fix extension (koenvandesande, May 29, 2019)
2ac26ce  Add ZippedImageFolder class which reads a zipped version of the data … (koenvandesande, May 31, 2019)
629c851  Merge branch 'master' into read_zipped_data (koenvandesande, Jul 12, 2019)
9396c35  Fix flake8 (koenvandesande, Jul 12, 2019)
48894bf  Update test_zippedfolder.py (koenvandesande, Jul 12, 2019)
0fa8035  Fix test (koenvandesande, Jul 13, 2019)
c05281a  Fix omniglot (koenvandesande, Jul 13, 2019)
ef3ba78  0.4.0 packaging (#1212) (ezyang, Aug 8, 2019)
66bc6f9  Don't build nightlies on 0.4.0 branch. (ezyang, Aug 8, 2019)
a1ed206  Refactor version suffix so conda packages don't get suffixes. (#1218)… (ezyang, Aug 8, 2019)
7a8b133  Merge remote-tracking branch 'upstream/v0.4.0' into read_zipped_data (koenvandesande, Aug 30, 2019)
b8c2c5d  Merge remote-tracking branch 'upstream/master' into read_zipped_data (koenvandesande, Aug 30, 2019)
8df35fa  Merge branch 'master' into read_zipped_data (koenvandesande, Oct 4, 2019)
d68ce83  Merge branch 'master' into read_zipped_data (koenvandesande, Oct 4, 2019)
17de30d  Merge branch 'master' into read_zipped_data (koenvandesande, Mar 10, 2020)
393cfd6  Update config.yml (koenvandesande, Mar 10, 2020)
f28b324  Remove EOL (koenvandesande, Mar 10, 2020)
6247d96  Merge branch 'master' into read_zipped_data (koenvandesande, Oct 17, 2020)
1 change: 1 addition & 0 deletions .travis.yml
@@ -29,6 +29,7 @@ before_install:
pip uninstall -y pillow && CC="cc -march=native" pip install --force-reinstall pillow-simd
fi
- pip install future
- pip install pandas tqdm
- pip install pytest pytest-cov codecov
- pip install typing
- |
28 changes: 28 additions & 0 deletions test/test_datasets.py
@@ -1,5 +1,7 @@
import sys
import os
import shutil
import tempfile
import unittest
from unittest import mock
import numpy as np
@@ -230,6 +232,32 @@ def test_svhn(self, mock_check):
dataset = torchvision.datasets.SVHN(root, split="extra")
self.generic_classification_dataset_test(dataset, num_images=2)

def test_celeba(self):
temp_dir = tempfile.mkdtemp()
ds = torchvision.datasets.CelebA(root=temp_dir, download=True)
assert len(ds) == 162770
assert ds[40711] is not None

# 2nd time, the ZIP file will be detected (because now it has been downloaded)
ds2 = torchvision.datasets.CelebA(root=temp_dir, download=True)
assert ds2.root_zip is not None, "Transparent ZIP reading support broken: ZIP file not found"
assert len(ds2) == 162770
assert ds2[40711] is not None
shutil.rmtree(temp_dir)

def test_omniglot(self):
temp_dir = tempfile.mkdtemp()
ds = torchvision.datasets.Omniglot(root=temp_dir, download=True)
assert len(ds) == 19280
assert ds[4071] is not None

# 2nd time, the ZIP file will be detected (because now it has been downloaded)
ds2 = torchvision.datasets.Omniglot(root=temp_dir, download=True)
assert ds2.root_zip is not None, "Transparent ZIP reading support broken: ZIP file not found"
assert len(ds2) == 19280
assert ds2[4071] is not None
shutil.rmtree(temp_dir)

@mock.patch('torchvision.datasets.voc.download_extract')
def test_voc_parse_xml(self, mock_download_extract):
with voc_root() as root:
48 changes: 48 additions & 0 deletions test/test_datasets_utils.py
@@ -1,6 +1,7 @@
import os
import sys
import tempfile
import shutil
import torchvision.datasets.utils as utils
import unittest
import zipfile
@@ -64,6 +65,53 @@ def test_download_url_dont_exist(self):
with self.assertRaises(URLError):
utils.download_url(url, temp_dir)

def test_convert_zip_to_uncompressed_zip(self):
temp_dir = tempfile.mkdtemp()
temp_filename = os.path.join(temp_dir, "convert.zip")
temp_filename2 = os.path.join(temp_dir, "converted.zip")
try:
z = zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_DEFLATED, allowZip64=True)
z.write(TEST_FILE, "hopper.jpg")
z.write(TEST_FILE)
z.write(TEST_FILE, "hopper79.jpg")
z.write(TEST_FILE, "somepath/hopper.jpg")
z.close()

utils.convert_zip_to_uncompressed_zip(temp_filename, temp_filename2)
with zipfile.ZipFile(temp_filename2) as u:
for info in u.infolist():
assert info.compress_type == zipfile.ZIP_STORED
_ = utils.ZipLookup(temp_filename2)
finally:
shutil.rmtree(temp_dir)

def test_ziplookup(self):
temp_dir = tempfile.mkdtemp()
temp_filename = os.path.join(temp_dir, "ziplookup.zip")
try:
z = zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_STORED, allowZip64=True)
z.write(TEST_FILE, "hopper.jpg")
z.write(TEST_FILE)
z.write(TEST_FILE, "hopper79.jpg")
z.write(TEST_FILE, "somepath/hopper.jpg")
z.close()

lookup = utils.ZipLookup(temp_filename)
f = lookup["hopper.jpg"]
assert f.name.endswith(".jpg")
f = lookup["somepath/hopper.jpg"]
assert f.name.endswith(".jpg")
try:
f = lookup["does_not_exist.jpg"]
assert False, "Should not return something for non-existent file"
except KeyError:
pass
assert "hopper.jpg" in lookup.keys()
assert "somepath/hopper.jpg" in lookup.keys()
del lookup
finally:
shutil.rmtree(temp_dir)

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_zip(self):
with get_tmp_dir() as temp_dir:
50 changes: 50 additions & 0 deletions test/test_zippedfolder.py
@@ -0,0 +1,50 @@
import unittest

import tempfile
import os
import shutil
import zipfile
from common_utils import get_tmp_dir

from torchvision.datasets import ZippedImageFolder
from torch._utils_internal import get_file_path_2


class Tester(unittest.TestCase):
FAKEDATA_DIR = get_file_path_2(os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

def test_zipped_image_folder(self):
temp_dir = tempfile.mkdtemp()
temp_filename = os.path.join(temp_dir, "dataset.zip")
try:
with get_tmp_dir(src=os.path.join(Tester.FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b'])
class_a_image_files = [os.path.join(root, 'a', file)
for file in ('a1.png', 'a2.png', 'a3.png')]
class_b_image_files = [os.path.join(root, 'b', file)
for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]

zf = zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_STORED, allowZip64=True)
for dirname, subdirs, files in os.walk(root):
for filename in files:
zf.write(os.path.join(dirname, filename),
os.path.relpath(os.path.join(dirname, filename), root))
zf.close()

dataset = ZippedImageFolder(root=temp_filename)
for cls in classes:
self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])
class_a_idx = dataset.class_to_idx['a']
class_b_idx = dataset.class_to_idx['b']
imgs_a = [(img_path.replace(root + os.path.sep, '').replace(os.path.sep, "/"), class_a_idx)
for img_path in class_a_image_files]
imgs_b = [(img_path.replace(root + os.path.sep, '').replace(os.path.sep, "/"), class_b_idx)
for img_path in class_b_image_files]
imgs = sorted(imgs_a + imgs_b)
self.assertEqual(imgs, dataset.imgs)
finally:
shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
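
For context, a minimal usage sketch of the new ZippedImageFolder class follows. The transform and DataLoader settings are illustrative and not part of this PR; the constructor is assumed to accept the same keyword arguments as ImageFolder, and multi-worker reading is assumed to be safe because ZipLookup re-opens the archive for every read:

# Hedged sketch: assumes dataset.zip is an uncompressed (ZIP_STORED) archive laid out
# like an ImageFolder tree (class_name/image.png), as built in the test above.
import torch
from torchvision import transforms
from torchvision.datasets import ZippedImageFolder

dataset = ZippedImageFolder(root="dataset.zip", transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for images, labels in loader:
    pass  # images are decoded straight from the ZIP, with no extraction step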
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
@@ -1,5 +1,6 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder, DatasetFolder
from .zippedfolder import ZippedImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
@@ -25,7 +26,7 @@
from .places365 import Places365

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'ImageFolder', 'DatasetFolder', 'ZippedImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
19 changes: 10 additions & 9 deletions torchvision/datasets/celeba.py
@@ -59,7 +59,8 @@ def __init__(
download: bool = False,
) -> None:
import pandas
super(CelebA, self).__init__(root, transform=transform,
root = os.path.join(root, self.base_folder)
super(CelebA, self).__init__(root, root_zipfilename=os.path.join(root, "img_align_celeba.zip"), transform=transform,
target_transform=target_transform)
self.split = split
if isinstance(target_type, list):
@@ -86,7 +87,7 @@ def __init__(
split_ = split_map[verify_str_arg(split.lower(), "split",
("train", "valid", "test", "all"))]

fn = partial(os.path.join, self.root, self.base_folder)
fn = partial(os.path.join, self.root)
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
@@ -105,15 +106,15 @@

def _check_integrity(self) -> bool:
for (_, md5, filename) in self.file_list:
fpath = os.path.join(self.root, self.base_folder, filename)
fpath = os.path.join(self.root, filename)
_, ext = os.path.splitext(filename)
# Allow original archive to be deleted (zip and 7z)
# Only need the extracted images
if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
return False

# Should check a hash of the images
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
return os.path.isdir(os.path.join(self.root, "img_align_celeba"))

def download(self) -> None:
import zipfile
@@ -123,13 +124,13 @@ def download(self) -> None:
return

for (file_id, md5, filename) in self.file_list:
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
download_file_from_google_drive(file_id, self.root, filename, md5)

with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
f.extractall(os.path.join(self.root, self.base_folder))
with zipfile.ZipFile(os.path.join(self.root, "img_align_celeba.zip"), "r") as f:
f.extractall(self.root)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
def __getitem__(self, index):
X = PIL.Image.open(self.get_path_or_fp("img_align_celeba", self.filename[index]))

target: Any = []
for t in self.target_type:
10 changes: 6 additions & 4 deletions torchvision/datasets/coco.py
@@ -54,7 +54,8 @@
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
super(CocoCaptions, self).__init__(root, transforms, transform, target_transform,
root_zipfilename=root + ".zip")
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
@@ -75,7 +76,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:

path = coco.loadImgs(img_id)[0]['file_name']

img = Image.open(os.path.join(self.root, path)).convert('RGB')
img = Image.open(self.get_path_or_fp(path)).convert('RGB')

if self.transforms is not None:
img, target = self.transforms(img, target)
@@ -108,7 +109,8 @@ def __init__(
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
super(CocoDetection, self).__init__(root, transforms, transform, target_transform,
root_zipfilename=root + ".zip")
from pycocotools.coco import COCO
self.coco = COCO(annFile)
self.ids = list(sorted(self.coco.imgs.keys()))
@@ -128,7 +130,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:

path = coco.loadImgs(img_id)[0]['file_name']

img = Image.open(os.path.join(self.root, path)).convert('RGB')
img = Image.open(self.get_path_or_fp(path)).convert('RGB')
if self.transforms is not None:
img, target = self.transforms(img, target)

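
The VisionDataset side of this change (the root_zipfilename argument and the get_path_or_fp helper) is not included in the hunks shown here. Based on the call sites above and the ZipLookup class added to utils.py below, a rough sketch of the presumed behaviour is:

# Rough sketch only, not the actual VisionDataset code from this PR.
import os

class VisionDatasetSketch:
    def __init__(self, root, root_zipfilename=None):
        self.root = root
        self.root_zip = None  # the attribute the new tests check
        if root_zipfilename is not None and os.path.isfile(root_zipfilename):
            from torchvision.datasets.utils import ZipLookup  # added in this PR
            self.root_zip = ZipLookup(root_zipfilename)

    def get_path_or_fp(self, *parts):
        # Serve a file-like object from the ZIP when it exists, else a plain path on disk.
        if self.root_zip is not None:
            return self.root_zip["/".join(parts)]  # io.BytesIO with the raw bytes
        return os.path.join(self.root, *parts)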
18 changes: 11 additions & 7 deletions torchvision/datasets/omniglot.py
@@ -3,7 +3,7 @@
import os
from typing import Any, Callable, List, Optional, Tuple
from .vision import VisionDataset
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files, convert_zip_to_uncompressed_zip


class Omniglot(VisionDataset):
@@ -36,9 +36,11 @@ def __init__(
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
target_transform=target_transform)
self.background = background
super(Omniglot, self).__init__(join(root, self.folder),
root_zipfilename=join(root, self.folder, self._get_target_folder() + ".zip"),
transform=transform,
target_transform=target_transform)

if download:
self.download()
@@ -67,8 +69,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
tuple: (image, target) where target is index of the target character class.
"""
image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode='r').convert('L')
image_path_or_fp = self.get_path_or_fp(self._get_target_folder(), self._characters[character_class], image_name)
image = Image.open(image_path_or_fp, mode='r').convert('L')

if self.transform:
image = self.transform(image)
@@ -80,7 +82,7 @@

def _check_integrity(self) -> bool:
zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
if not check_integrity(join(self.root, zip_filename + '.org.zip'), self.zips_md5[zip_filename]):
return False
return True

@@ -91,8 +93,10 @@ def download(self) -> None:

filename = self._get_target_folder()
zip_filename = filename + '.zip'
org_filename = filename + '.org.zip'
url = self.download_url_prefix + '/' + zip_filename
download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
download_and_extract_archive(url, self.root, filename=org_filename, md5=self.zips_md5[filename])
convert_zip_to_uncompressed_zip(join(self.root, org_filename), join(self.root, zip_filename))

def _get_target_folder(self) -> str:
return 'images_background' if self.background else 'images_evaluation'
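
The Omniglot archives are served deflate-compressed, while ZipLookup (see the utils.py hunk below) only hands out raw byte ranges and rejects compressed entries, so download() keeps the original as .org.zip and re-writes an uncompressed copy. A hedged sketch of doing the same conversion by hand, with illustrative paths:

# Sketch using the helpers added in this PR's utils.py; the paths are made up.
from os.path import join
from torchvision.datasets.utils import ZipLookup, convert_zip_to_uncompressed_zip

root = "data/omniglot-py"
convert_zip_to_uncompressed_zip(join(root, "images_background.org.zip"),
                                join(root, "images_background.zip"))
lookup = ZipLookup(join(root, "images_background.zip"))  # would raise ValueError if any entry were still compressed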
47 changes: 47 additions & 0 deletions torchvision/datasets/utils.py
@@ -9,6 +9,9 @@

import torch
from torch.utils.model_zoo import tqdm
import io
import struct
from torch._six import PY3


def gen_bar_updater() -> Callable[[int, int, int], None]:
@@ -260,6 +263,50 @@ def download_and_extract_archive(
extract_archive(archive, extract_root, remove_finished)


def convert_zip_to_uncompressed_zip(org_filename, zip_filename):
with zipfile.ZipFile(org_filename, 'r') as zip_file:
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_STORED) as out_file:
for item in zip_file.infolist():
out_file.writestr(item.filename, zip_file.read(item))


# thread-safe/multiprocessing-safe (unlike a Python ZipFile instance)
class ZipLookup(object):
def __init__(self, filename):
self.root_zip_filename = filename
self.root_zip_lookup = {}

with zipfile.ZipFile(filename, "r") as root_zip:
for info in root_zip.infolist():
if info.filename[-1] == '/':
# skip directories
continue
if info.compress_type != zipfile.ZIP_STORED:
raise ValueError("Only uncompressed ZIP file supported: " + info.filename)
if info.compress_size != info.file_size:
raise ValueError("Must be the same when uncompressed")
self.root_zip_lookup[info.filename] = (info.header_offset, info.compress_size)

def __getitem__(self, path):
z = open(self.root_zip_filename, "rb")
header_offset, size = self.root_zip_lookup[path]

z.seek(header_offset)
fheader = z.read(zipfile.sizeFileHeader)
fheader = struct.unpack(zipfile.structFileHeader, fheader)
offset = header_offset + zipfile.sizeFileHeader + fheader[zipfile._FH_FILENAME_LENGTH] + \
fheader[zipfile._FH_EXTRA_FIELD_LENGTH]

z.seek(offset)
f = io.BytesIO(z.read(size))
f.name = path
z.close()
return f

def keys(self):
return self.root_zip_lookup.keys()


def iterable_to_str(iterable: Iterable) -> str:
return "'" + "', '".join([str(item) for item in iterable]) + "'"

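
For illustration, a minimal sketch of using ZipLookup directly; the archive name and entry path are placeholders, and the archive is assumed to be uncompressed (ZIP_STORED) as the class requires:

# Placeholders only: "train2017.zip" and the entry name are not real files.
from PIL import Image
from torchvision.datasets.utils import ZipLookup

lookup = ZipLookup("train2017.zip")  # builds a name -> (offset, size) index once
fp = lookup["000000397133.jpg"]      # io.BytesIO holding the raw image bytes
img = Image.open(fp).convert("RGB")
print(sorted(lookup.keys())[:5])     # entry names are the dictionary keys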