forked from fillassuncao/denser-models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfashion_mnist.py
41 lines (34 loc) · 2.21 KB
/
fashion_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import gzip
import os
import sys

if sys.version_info >= (3, 0):
    import urllib.request
else:
    import urllib

import idx2numpy
# Base URL of the Fashion-MNIST distribution this module downloads from.
_FASHION_MNIST_ORIGIN = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'

# The four gzipped IDX files of the dataset, in the order
# x_train, y_train, x_test, y_test (the order load_data unpacks them in).
_FASHION_MNIST_FILES = (
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
)


def _urlretrieve(url, filename):
    """Download `url` to `filename` with the Python 2 or Python 3 urllib API."""
    if sys.version_info >= (3, 0):
        return urllib.request.urlretrieve(url, filename)
    return urllib.urlretrieve(url, filename)


def _maybe_download(url, filepath):
    """Download `url` to `filepath` unless `filepath` already exists.

    The data is first written to `filepath + '.part'` and only renamed into
    place once the download finished, so an interrupted download never leaves
    a truncated file that a later call would mistake for a valid cached copy.

    # Returns
        `filepath`.
    """
    if os.path.exists(filepath):
        return filepath
    partial = filepath + '.part'
    try:
        _urlretrieve(url, partial)
        os.rename(partial, filepath)
    finally:
        # The .part file only survives to here if the download or rename failed.
        if os.path.exists(partial):
            os.remove(partial)
    return filepath


def load_data(path='.', origin=_FASHION_MNIST_ORIGIN):
    """Loads the Fashion MNIST dataset, downloading it on first use.

    Each gzipped IDX file is downloaded only if it is not already present in
    `path`, then decompressed and decoded with `idx2numpy`.

    # Arguments
        path: directory in which the gzipped IDX files are cached. Defaults
            to the current working directory, which is where this function
            has always stored them. Created if it does not exist.
        origin: base URL the files are fetched from; each file name is
            appended to it. Override to download from a mirror.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    # os.makedirs(..., exist_ok=True) is Python 3 only; this module also
    # supports Python 2.
    if not os.path.isdir(path):
        os.makedirs(path)
    base_url = origin.rstrip('/') + '/'
    arrays = []
    for filename in _FASHION_MNIST_FILES:
        filepath = _maybe_download(base_url + filename,
                                   os.path.join(path, filename))
        with gzip.open(filepath, 'rb') as f:
            arrays.append(idx2numpy.convert_from_string(f.read()))
    x_train, y_train, x_test, y_test = arrays
    return (x_train, y_train), (x_test, y_test)