-
Notifications
You must be signed in to change notification settings - Fork 37
/
datasets.py
155 lines (131 loc) · 5.78 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
from torchvision.datasets.mnist import MNIST
class EMNIST(MNIST):
    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.

    The processed data for each split is stored as ``training_<split>.pt`` and
    ``test_<split>.pt`` (see ``_training_file`` / ``_test_file``).

    Args:
        root (string): Root directory of dataset where
            ``processed/training_<split>.pt`` and ``processed/test_<split>.pt`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from the training file,
            otherwise from the test file.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    All arguments other than ``root`` and ``split`` are forwarded unchanged to
    ``MNIST.__init__`` as keyword arguments.

    Raises:
        RuntimeError: If ``split`` is not one of ``EMNIST.splits``.
    """
    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        if split not in self.splits:
            raise RuntimeError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.split = split
        # These must be set before the parent constructor runs, so that it
        # sees the split-specific file names.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    def _training_file(self, split):
        """Return the processed training file name for ``split``."""
        return 'training_{}.pt'.format(split)

    def _test_file(self, split):
        """Return the processed test file name for ``split``."""
        return 'test_{}.pt'.format(split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already.

        Fetches the zip archive at ``self.url``, unpacks every ``.gz`` member
        into the raw folder, then converts all splits to torch files in the
        processed folder.
        """
        from six.moves import urllib
        import contextlib
        import gzip
        import shutil
        import zipfile

        if self._check_exists():
            return

        raw_folder = os.path.join(self.root, self.raw_folder)
        processed_folder = os.path.join(self.root, self.processed_folder)

        # Create each directory independently: with a single shared try block,
        # an already-existing raw folder would raise EEXIST on the first call
        # and the processed folder would never be created.
        for folder in (raw_folder, processed_folder):
            try:
                os.makedirs(folder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        # download files
        # NOTE(review): the archive is fetched over plain HTTP and extracted
        # without any checksum verification.
        print('Downloading ' + self.url)
        filename = self.url.rpartition('/')[2]
        file_path = os.path.join(raw_folder, filename)
        # Stream to disk instead of holding the whole archive in memory, and
        # make sure the HTTP response is closed afterwards.
        with contextlib.closing(urllib.request.urlopen(self.url)) as response, \
                open(file_path, 'wb') as f:
            shutil.copyfileobj(response, f)

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(raw_folder)
        os.unlink(file_path)

        gzip_folder = os.path.join(raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                print('Extracting ' + gzip_file)
                # Strip only the trailing '.gz' suffix from the output name.
                out_path = os.path.join(raw_folder, gzip_file[:-len('.gz')])
                with open(out_path, 'wb') as out_f, \
                        gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
                    shutil.copyfileobj(zip_f, out_f)
        shutil.rmtree(gzip_folder)

        # process and save as torch files
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
            )
            test_set = (
                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
            )
            with open(os.path.join(processed_folder, self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(processed_folder, self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)

        print('Done!')
class AgirEcole(MNIST):
    """Agir pour l'ecole dataset.

    Accepts one of the splits listed in ``splits``. Both the training file
    and the test file resolve to the same ``data-<split>.pt`` name. All
    keyword arguments other than ``root`` and ``split`` are passed through
    to ``MNIST.__init__``.
    """
    splits = ("val", "test")

    def __init__(self, root, split, **kwargs):
        is_known = split in self.splits
        if not is_known:
            valid_names = ', '.join(self.splits)
            message = 'Split "{}" not found. Valid splits are: {}'.format(
                split, valid_names,
            )
            raise RuntimeError(message)

        # File names have to be in place before the parent constructor runs.
        self.split = split
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(AgirEcole, self).__init__(root, **kwargs)

    def _training_file(self, split):
        """Return the data file name used as the training file for ``split``."""
        name = 'data-{}.pt'.format(split)
        return name

    def _test_file(self, split):
        """Return the data file name used as the test file for ``split``."""
        name = 'data-{}.pt'.format(split)
        return name
def get_int(b):
    """Interpret the byte string ``b`` as a big-endian unsigned integer."""
    # Hex-encode the bytes, then parse the hex digits as a base-16 number.
    hex_digits = codecs.encode(b, 'hex')
    value = int(hex_digits, 16)
    return value
def read_label_file(path):
    """Load a label file (magic number 2049) as a 1-D ``long`` tensor.

    The file layout read here is: a 4-byte magic number, a 4-byte item
    count, then one unsigned byte per label.

    Args:
        path: Path of the label file to read.

    Returns:
        A ``torch`` tensor of ``long`` values with one entry per label.

    Raises:
        AssertionError: If the magic number is not 2049.
    """
    with open(path, 'rb') as handle:
        raw = handle.read()

    magic = get_int(raw[:4])
    assert magic == 2049

    count = get_int(raw[4:8])
    # Everything after the 8-byte header is label data.
    label_bytes = np.frombuffer(raw, dtype=np.uint8, offset=8)
    labels = torch.from_numpy(label_bytes)
    return labels.view(count).long()
def read_image_file(path):
    """Load an image file (magic number 2051) as a 3-D ``uint8`` tensor.

    The file layout read here is: a 4-byte magic number, then 4-byte image
    count, row count and column count, then one unsigned byte per pixel.

    Args:
        path: Path of the image file to read.

    Returns:
        A ``torch`` ``uint8`` tensor viewed as ``(length, num_rows, num_cols)``.
        The tensor owns its own writable storage.

    Raises:
        AssertionError: If the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()

    assert get_int(data[:4]) == 2051, \
        'unexpected magic number in image file {}'.format(path)
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # np.frombuffer over an immutable bytes object yields a read-only array,
    # and torch.from_numpy shares memory with it. Copy so that the returned
    # tensor is backed by writable memory it owns.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16).copy()
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)