|
1 | 1 | import logging |
2 | 2 | import gzip |
3 | 3 | import struct |
4 | | -import urllib.request |
5 | 4 | import os |
6 | 5 | import os.path as path |
7 | 6 |
|
|
15 | 14 | 'test_labels': 't10k-labels-idx1-ubyte.gz'} |
16 | 15 |
|
17 | 16 |
|
18 | | -class MNISTDataset(cx.BaseDataset): |
| 17 | +class MNISTDataset(cx.DownloadableDataset): |
19 | 18 | """ MNIST dataset for hand-written digits recognition.""" |
20 | 19 |
|
21 | | - def _configure_dataset(self, data_root=path.join('datasets', '.mnist-data'), batch_size:int=100, **kwargs) -> None: |
| 20 | + def _configure_dataset(self, data_root=path.join('mnist_convnet', '.mnist-data'), batch_size:int=100, **kwargs) -> None: |
22 | 21 | self._batch_size = batch_size |
23 | 22 | self._data_root = data_root |
| 23 | + self._download_urls = [path.join(DOWNLOAD_ROOT, filename) for filename in FILENAMES.values()] |
24 | 24 | self._data = {} |
25 | 25 | self._data_loaded = False |
26 | 26 |
|
@@ -52,14 +52,3 @@ def test_stream(self) -> cx.Stream: |
52 | 52 | for i in range(0, len(self._data['test_labels']), self._batch_size): |
53 | 53 | yield {'images': self._data['test_images'][i: i + self._batch_size], |
54 | 54 | 'labels': self._data['test_labels'][i: i + self._batch_size]} |
55 | | - |
56 | | - def download(self) -> None: |
57 | | - """Download method may be invoked with `cxflow dataset download <path-to-config>`.""" |
58 | | - for part in FILENAMES.values(): |
59 | | - target = path.join(self._data_root, part) |
60 | | - if path.exists(target): |
61 | | - logging.info('\t%s already exists', target) |
62 | | - else: |
63 | | - os.makedirs(self._data_root, exist_ok=True) |
64 | | - logging.info('\tdownloading %s', target) |
65 | | - urllib.request.urlretrieve(DOWNLOAD_ROOT+part, target) |
0 commit comments