From cc7e856adcd1f9ede9de118b1693f68c5789deba Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 4 Jan 2022 11:49:43 +0100 Subject: [PATCH 1/2] add contribution instructions for prototype datasets (#5133) * add contribution instructions for prototype datasets * cleanup * fix links * Update torchvision/prototype/datasets/_builtin/README.md --- .../prototype/datasets/_builtin/README.md | 127 ++++++++++++++++++ .../prototype/datasets/_builtin/caltech.py | 10 +- .../prototype/datasets/_builtin/cifar.py | 5 +- .../prototype/datasets/_builtin/coco.py | 2 +- .../prototype/datasets/_builtin/imagenet.py | 3 +- .../prototype/datasets/_builtin/mnist.py | 20 ++- .../prototype/datasets/_builtin/sbd.py | 4 +- .../datasets/generate_category_files.py | 9 +- .../prototype/datasets/utils/_dataset.py | 4 +- .../prototype/datasets/utils/_resource.py | 4 +- 10 files changed, 169 insertions(+), 19 deletions(-) create mode 100644 torchvision/prototype/datasets/_builtin/README.md diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md new file mode 100644 index 00000000000..20bcd7b89bb --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/README.md @@ -0,0 +1,127 @@ +# How to add new built-in prototype datasets + +As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means that this document will also change a lot. + +If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out. + +Finally, `from torchvision.prototype import datasets` is implied below. + +## Implementation + +Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be discussed in detail below: + +```python +import io +from typing import Any, Callable, Dict, List, Optional + +import torch +from torchdata.datapipes.iter import IterDataPipe +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource + +class MyDataset(Dataset): + def _make_info(self) -> DatasetInfo: + ... + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + ... + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + ... +``` + +### `_make_info(self)` + +The `DatasetInfo` carries static information about the dataset. There are two required fields: +- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain lower characters. +- `type`: Field of the `datasets.utils.DatasetType` enum. This is used to select the default decoder in case the user doesn't pass one. There are currently only two options: `IMAGE` and `RAW` ([see below](what-is-the-datasettyperaw-and-when-do-i-use-it) for details). + +There are more optional parameters that can be passed: + +- `dependencies`: Collection of third-party dependencies that are needed to load the dataset, e.g. `("scipy",)`. Their availability will be automatically checked if a user tries to load the dataset. Within the implementation, import these packages lazily to avoid missing dependencies at import time. +- `categories`: Sequence of human-readable category names for each label. The index of each category has to match the corresponding label returned in the dataset samples. [See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. +- `valid_options`: Configures valid options that can be passed to the dataset. It should be `Dict[str, Sequence[str]]`. The options are accessible through the `config` namespace in the other two functions. First value of the sequence is taken as default if the user passes no option to `torchvision.prototype.datasets.load()`. + +## `resources(self, config)` + +Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset with a specific `config` can be build. The download will happen automatically. + +Currently, the following `OnlineResource`'s are supported: + +- `HttpResource`: Used for files that are directly exposed through HTTP(s) and only requires the URL. +- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`. +- `ManualDownloadResource`: Used files are not publicly accessible and requires instructions how to download them manually. If the file does not exist, an error will be raised with the supplied instructions. + +Although optional in general, all resources used in the built-in datasets should comprise [SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the download. You can compute the checksum with system utilities or this snippet: + +```python +import hashlib + +def sha256sum(path, chunk_size=1024 * 1024): + checksum = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + checksum.update(chunk) + print(checksum.hexdigest()) +``` + +### `_make_datapipe(resource_dps, *, config, decoder)` + +This method is the heart of the dataset that need to transform the raw data into a usable form. A major difference compared to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone that is working with them rather than on them, `IterDataPipe`'s behave just as generators, i.e. you can't do anything with them besides iterating. + +Of course, there are some common building blocks that should suffice in 95% of the cases. The most used + +- `Mapper`: Apply a callable to every item in the datapipe. +- `Filter`: Keep only items that satisfy a condition. +- `Demultiplexer`: Split a datapipe into multiple ones. +- `IterKeyZipper`: Merge two datapipes into one. + +All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated to add one. See the MNIST or CelebA datasets for example. + +`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return value of `resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain tuples comprised of the path and the handle for every file in the archive. Otherwise the datapipe will only contain one of such tuples for the file specified by the resource. + +Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and `Grouper`. There are two issues with that: +1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely. +2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime. + +Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than trying to zip already loaded images. + +There are two special datapipes that are not used through their class, but through the functions `hint_sharding` and `hint_shuffling`. As the name implies they only hint part in the datapipe graph where sharding and shuffling should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are required in each dataset. + +Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the names (yet!). + +## FAQ + +### What is the `DatasetType.RAW` and when do I use it? + +`DatasetType.RAW` marks dataset that provides decoded, i.e. raw pixel values, rather than encoded image files such as +`.jpg` or `.png`. This is usually only the case for small datasets, since it requires a lot more disk space. The default decoder `datasets.decoder.raw` is only a sentinel and should not be called directly. The decoding should look something like + +```python +from torchvision.prototype.datasets.decoder import raw + +image = ... + +if decoder is raw: + image = Image(image) +else: + image_buffer = image_buffer_from_raw(image) + image = decoder(image_buffer) if decoder else image_buffer +``` + +For examples, have a look at the MNIST, CIFAR, or SEMEION datasets. + +### How do I handle a dataset that defines many categories? + +As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only be set directly for ten categories or fewer. If more categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a category. If `$NAME` matches the name of the dataset (which it definitively should!) it will be automatically loaded if `categories=` is not set. + +In case the categories can be generated from the dataset files, e.g. the dataset follow an image folder approach where each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. It gets passed the `root` path to the resources, but they have to be manually loaded, e.g. `self.resources(config)[0].load(root)`. The method should return a sequence of strings representing the category names. To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`. + +### What if a resource file forms an I/O bottleneck? + +In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if the performance hit becomes significant, the archives can still be decompressed or extracted. To do this, the `decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should return `pathlib.Path` of the preprocessed file or folder. diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index f4f8c44f8ee..be19b7c240f 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -136,8 +136,11 @@ def _make_datapipe( return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> List[str]: - dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) + resources = self.resources(self.default_config) + + dp = resources[0].load(root) dp = Filter(dp, self._is_not_background_image) + return sorted({pathlib.Path(path).parent.name for path, _ in dp}) @@ -189,6 +192,9 @@ def _make_datapipe( return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> List[str]: - dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) + resources = self.resources(self.default_config) + + dp = resources[0].load(root) dir_names = {pathlib.Path(path).parent.name for path, _ in dp} + return [name.split(".")[1] for name in sorted(dir_names)] diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 68147ba0f9e..6d0cd465982 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -92,9 +92,12 @@ def _make_datapipe( return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> List[str]: - dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) + resources = self.resources(self.default_config) + + dp = resources[0].load(root) dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) dp = Mapper(dp, self._unpickle) + return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index e400a1db07d..6fde966402c 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -238,7 +238,7 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: config = self.default_config resources = self.resources(config) - dp = resources[1].load(pathlib.Path(root) / self.name) + dp = resources[1].load(root) dp = Filter( dp, functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"), diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 9ea70296427..ac3649c8839 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -177,7 +177,8 @@ def _make_datapipe( def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: resources = self.resources(self.default_config) - devkit_dp = resources[1].load(root / self.name) + + devkit_dp = resources[1].load(root) devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) meta = next(iter(devkit_dp))[1] diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 8f49f1ce72a..0d7fe36a3fd 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -4,7 +4,7 @@ import operator import pathlib import string -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence import torch from torchdata.datapipes.iter import ( @@ -78,7 +78,7 @@ def __iter__(self) -> Iterator[torch.Tensor]: class _MNISTBase(Dataset): - _URL_BASE: str + _URL_BASE: Union[str, Sequence[str]] @abc.abstractmethod def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: @@ -90,8 +90,15 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: labels_sha256, ) = self._files_and_checksums(config) - images = HttpResource(f"{self._URL_BASE}/{images_file}", sha256=images_sha256) - labels = HttpResource(f"{self._URL_BASE}/{labels_file}", sha256=labels_sha256) + url_bases = self._URL_BASE + if isinstance(url_bases, str): + url_bases = (url_bases,) + + images_urls = [f"{url_base}/{images_file}" for url_base in url_bases] + images = HttpResource(images_urls[0], sha256=images_sha256, mirrors=images_urls[1:]) + + labels_urls = [f"{url_base}/{labels_file}" for url_base in url_bases] + labels = HttpResource(labels_urls[0], sha256=images_sha256, mirrors=labels_urls[1:]) return [images, labels] @@ -151,7 +158,10 @@ def _make_info(self) -> DatasetInfo: ), ) - _URL_BASE = "http://yann.lecun.com/exdb/mnist" + _URL_BASE: Union[str, Sequence[str]] = ( + "http://yann.lecun.com/exdb/mnist", + "https://ossci-datasets.s3.amazonaws.com/mnist/", + ) _CHECKSUMS = { "train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609", "train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c", diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 82fdb2adf8b..f605d7d72f1 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -156,7 +156,9 @@ def _make_datapipe( return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: - dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) + resources = self.resources(self.default_config) + + dp = resources[0].load(root) dp = Filter(dp, path_comparator("name", "category_names.m")) dp = LineReader(dp) dp = Mapper(dp, bytes.decode, input_col=1) diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py index 3ce908f79ef..40843ecf50b 100644 --- a/torchvision/prototype/datasets/generate_category_files.py +++ b/torchvision/prototype/datasets/generate_category_files.py @@ -11,7 +11,7 @@ def main(*names, force=False): - root = pathlib.Path(datasets.home()) + home = pathlib.Path(datasets.home()) for name in names: path = BUILTIN_DIR / f"{name}.categories" @@ -20,13 +20,14 @@ def main(*names, force=False): dataset = find(name) try: - categories = dataset._generate_categories(root) + categories = dataset._generate_categories(home / name) except NotImplementedError: continue - with open(path, "w", newline="") as file: + with open(path, "w") as file: + writer = csv.writer(file, lineterminator="\n") for category in categories: - csv.writer(file).writerow((category,) if isinstance(category, str) else category) + writer.writerow((category,) if isinstance(category, str) else category) def parse_args(argv=None): diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 242f9c961c0..6cac7dcd093 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -6,7 +6,7 @@ import itertools import os import pathlib -from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple, Collection import torch from torch.utils.data import IterDataPipe @@ -33,7 +33,7 @@ def __init__( name: str, *, type: Union[str, DatasetType], - dependencies: Sequence[str] = (), + dependencies: Collection[str] = (), categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, citation: Optional[str] = None, homepage: Optional[str] = None, diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 94603bfc81e..cf30c5ff302 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -136,14 +136,14 @@ def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> class HttpResource(OnlineResource): def __init__( - self, url: str, *, file_name: Optional[str] = None, mirrors: Optional[Sequence[str]] = None, **kwargs: Any + self, url: str, *, file_name: Optional[str] = None, mirrors: Sequence[str] = (), **kwargs: Any ) -> None: super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs) self.url = url self.mirrors = mirrors def _download(self, root: pathlib.Path) -> None: - for url in itertools.chain((self.url,), self.mirrors or ()): + for url in itertools.chain((self.url,), self.mirrors): try: download_url(url, str(root), filename=self.file_name, md5=None) # TODO: make this more precise From bbeb32035c0df752a0fe0cd5921de537e7f68d72 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 4 Jan 2022 11:38:28 +0000 Subject: [PATCH 2/2] Remove incorrect ViT recipe commands. (#5159) --- references/classification/README.md | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/references/classification/README.md b/references/classification/README.md index ff5371066d2..a73fde3679f 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -143,28 +143,6 @@ torchrun --nproc_per_node=8 train.py\ ``` Here `$MODEL` is one of `regnet_x_32gf`, `regnet_y_16gf` and `regnet_y_32gf`. -### Vision Transformer - -#### Base models -``` -torchrun --nproc_per_node=8 train.py\ - --model $MODEL --epochs 300 --batch-size 64 --opt adamw --lr 0.003 --wd 0.3\ - --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ - --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\ - --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema -``` -Here `$MODEL` is one of `vit_b_16` and `vit_b_32`. - -#### Large models -``` -torchrun --nproc_per_node=8 train.py\ - --model $MODEL --epochs 300 --batch-size 16 --opt adamw --lr 0.003 --wd 0.3\ - --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ - --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\ - --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema -``` -Here `$MODEL` is one of `vit_l_16` and `vit_l_32`. - ## Mixed precision training Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).