From dbdacd40300ba707cf9e63948ce2c39fc59adb72 Mon Sep 17 00:00:00 2001 From: sjoke Date: Tue, 20 Mar 2018 21:54:24 +0800 Subject: [PATCH 1/2] add a way to read four data files from local path --- torchvision/datasets/mnist.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index a467ebee554..8e7e5b811e4 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -36,15 +36,20 @@ class MNIST(data.Dataset): training_file = 'training.pt' test_file = 'test.pt' - def __init__(self, root, train=True, transform=None, target_transform=None, download=False): + def __init__(self, root, train=True, transform=None, target_transform=None, \ + download=False, from_local=False): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set + self.from_local = from_local if download: self.download() + elif self.from_local: + self.download() + if not self._check_exists(): raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') @@ -94,7 +99,7 @@ def _check_exists(self): def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" from six.moves import urllib - import gzip + import gzip, shutil if self._check_exists(): return @@ -110,12 +115,18 @@ def download(self): raise for url in self.urls: - print('Downloading ' + url) - data = urllib.request.urlopen(url) filename = url.rpartition('/')[2] file_path = os.path.join(self.root, self.raw_folder, filename) - with open(file_path, 'wb') as f: - f.write(data.read()) + + if self.from_local: + tmp_file_path = os.path.join(self.root, filename) + shutil.move(tmp_file_path, os.path.join(self.root, self.raw_folder)) + else: + print('Downloading ' + url) + data = urllib.request.urlopen(url) + with open(file_path, 'wb') as f: + f.write(data.read()) + with open(file_path.replace('.gz', ''), 'wb') as out_f, \ gzip.GzipFile(file_path) as zip_f: out_f.write(zip_f.read()) From d71d857717325149101e4b563a32dbd6cb296c0e Mon Sep 17 00:00:00 2001 From: sjoke Date: Wed, 21 Mar 2018 22:28:54 +0800 Subject: [PATCH 2/2] modify --- torchvision/datasets/mnist.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 8e7e5b811e4..7bbb4d54214 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -36,7 +36,7 @@ class MNIST(data.Dataset): training_file = 'training.pt' test_file = 'test.pt' - def __init__(self, root, train=True, transform=None, target_transform=None, \ + def __init__(self, root, train=True, transform=None, target_transform=None, download=False, from_local=False): self.root = os.path.expanduser(root) self.transform = transform @@ -99,7 +99,8 @@ def _check_exists(self): def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" from six.moves import urllib - import gzip, shutil + import gzip + import shutil if self._check_exists(): return