From 0599d33ee468733afce5375ec0e2563286446769 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 5 Apr 2022 01:11:59 -0700 Subject: [PATCH] [fbsync] USPS dataset (#5647) Summary: * added usps dataset * fixed type issues * fix mobilnet norm layer test (#5643) * xfail mobilnet norm layer test * fix test * More robust check in tests for 16 bits images (#5652) * Prefer nvidia channel for conda builds (#5648) To mitigate missing `libcupti.so` dependency * fix torchdata CI installation (#5657) * update urls for kinetics dataset (#5578) * update urls for kinetics dataset * update urls for kinetics dataset * remove errors * update the changes and add test option to split * added test to valid values for split arg * change .txt to .csv for annotation url of k600 * Port Multi-weight support from prototype to main (#5618) * Moving basefiles outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet. * Porting googlenet * Porting inception * Porting mnasnet * Porting mobilenetv2 * Porting mobilenetv3 * Porting regnet * Porting resnet * Porting shufflenetv2 * Porting squeezenet * Porting vgg * Porting vit * Fix docstrings * Fixing imports * Adding missing import * Fix mobilenet imports * Fix tests * Fix prototype tests * Exclude get_weight from models on test * Fix init files * Porting googlenet * Porting inception * porting mobilenetv2 * porting mobilenetv3 * porting resnet * porting shufflenetv2 * Fix test and linter * Fixing docs. * Porting Detection models (#5617) * fix inits * fix docs * Port faster_rcnn * Port fcos * Port keypoint_rcnn * Port mask_rcnn * Port retinanet * Port ssd * Port ssdlite * Fix linter * Fixing tests * Fixing tests * Fixing vgg test * Porting Optical Flow, Segmentation, Video models (#5619) * Porting raft * Porting video resnet * Porting deeplabv3 * Porting fcn and lraspp * Fixing the tests and linter * Porting docs, examples, tutorials and galleries (#5620) * Fix examples, tutorials and gallery * Update gallery/plot_optical_flow.py * Fix import * Revert hardcoded normalization * fix uncommitted changes * Fix bug * Fix more bugs * Making resize optional for segmentation * Fixing preset * Fix mypy * Fixing documentation strings * Fix flake8 * minor refactoring * Resolve conflict * Porting model tests (#5622) * Porting tests * Remove unnecessary variable * Fix linter * Move prototype to extended tests * Fix download models job * Update CI on Multiweight branch to use the new weight download approach (#5628) * port Pad to prototype transforms (#5621) * port Pad to prototype transforms * use literal * Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624) * pre-download model weights in CI docs build (#5625) * pre-download model weights in CI docs build * move changes into template * change docs image * Regenerated config.yml * Porting reference scripts and updating presets (#5629) * Making _preset.py classes * Remove support of targets on presets. * Rewriting the video preset * Adding tests to check that the bundled transforms are JIT scriptable * Rename all presets from *Eval to *Inference * Minor refactoring * Remove --prototype and --pretrained from reference scripts * remove pretained_backbone refs * Corrections and simplifications * Fixing bug * Fixing linter * Fix flake8 * restore documentation example * minor fixes * fix optical flow missing param * Fixing commands * Adding weights_backbone support in detection and segmentation * Updating the commands for InceptionV3 * Setting `weights_backbone` to its fully BC value (#5653) * Replace default `weights_backbone=None` with its BC values. * Fixing tests * Fix linter * Update docs. * Update preprocessing on reference scripts. * Change qat/ptq to their full values. * Refactoring preprocessing * Fix video preset * No initialization on VGG if pretrained * Fix warning messages for backbone utils. * Adding star to all preset constructors. * Fix mypy. * Apply suggestions from code review * use decompressor for extracting bz2 * Apply suggestions from code review * Apply suggestions from code review * fixed lint fails * added tests for USPS * check image shape * fix tests * check shape on image directly * Apply suggestions from code review * removed test and comments * Update test/test_prototype_builtin_datasets.py (Note: this ignores all push blocking failures!) Reviewed By: datumbox Differential Revision: D35216783 fbshipit-source-id: 556a63a89f15d1541ac2b479244a7b6c564eff14 Co-authored-by: Nicolas Hug Co-authored-by: Nicolas Hug Co-authored-by: Nicolas Hug Co-authored-by: Anton Thomma Co-authored-by: Vasilis Vryniotis Co-authored-by: Philip Meier Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com> Co-authored-by: Anton Thomma Co-authored-by: Nicolas Hug Co-authored-by: Philip Meier Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com> Co-authored-by: Anton Thomma Co-authored-by: Philip Meier Co-authored-by: Philip Meier Co-authored-by: Philip Meier Co-authored-by: Philip Meier Co-authored-by: Nicolas Hug Co-authored-by: Philip Meier Co-authored-by: Nicolas Hug Co-authored-by: Nikita Shulga Co-authored-by: Sahil Goyal Co-authored-by: Vasilis Vryniotis Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com> Co-authored-by: Anton Thomma --- test/builtin_dataset_mocks.py | 19 +++++++ test/test_prototype_builtin_datasets.py | 19 ++++++- .../prototype/datasets/_builtin/__init__.py | 1 + .../prototype/datasets/_builtin/usps.py | 54 +++++++++++++++++++ 4 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 torchvision/prototype/datasets/_builtin/usps.py diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 62259a604a0..1153c1b33f0 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1,3 +1,4 @@ +import bz2 import collections.abc import csv import functools @@ -1431,3 +1432,21 @@ def stanford_cars(info, root, config): make_tar(root, "car_devkit.tgz", devkit, compression="gz") return num_samples + + +@register_mock +def usps(info, root, config): + num_samples = {"train": 15, "test": 7}[config.split] + + with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh: + lines = [] + for _ in range(num_samples): + label = make_tensor(1, low=1, high=11, dtype=torch.int) + values = make_tensor(256, low=-1, high=1, dtype=torch.float) + lines.append( + " ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))]) + ) + + fh.write("\n".join(lines).encode()) + + return num_samples diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index f7c40d432a4..f414f4e48cd 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -12,7 +12,7 @@ from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets - +from torchvision.prototype.features import Image, Label assert_samples_equal = functools.partial( assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True @@ -180,3 +180,20 @@ def test_label_matches_path(self, test_home, dataset_mock, config): for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) assert sample["label"] == label_from_path + + +@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) +class TestUSPS: + def test_sample_content(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + + dataset = datasets.load(dataset_mock.name, **config) + + for sample in dataset: + assert "image" in sample + assert "label" in sample + + assert isinstance(sample["image"], Image) + assert isinstance(sample["label"], Label) + + assert sample["image"].shape == (1, 16, 16) diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index feb558aa03f..1a8dc0907a4 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -17,4 +17,5 @@ from .semeion import SEMEION from .stanford_cars import StanfordCars from .svhn import SVHN +from .usps import USPS from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py new file mode 100644 index 00000000000..5df0978d031 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List + +import torch +from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, Label + + +class USPS(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "usps", + homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", + valid_options=dict( + split=("train", "test"), + ), + categories=10, + ) + + _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass" + + _RESOURCES = { + "train": HttpResource( + f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f" + ), + "test": HttpResource( + f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e" + ), + } + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [USPS._RESOURCES[config.split]] + + def _prepare_sample(self, line: str) -> Dict[str, Any]: + label, *values = line.strip().split(" ") + values = [float(value.split(":")[1]) for value in values] + pixels = torch.tensor(values).add_(1).div_(2) + return dict( + image=Image(pixels.reshape(16, 16)), + label=Label(int(label) - 1, categories=self.categories), + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + dp = Decompressor(resource_dps[0]) + dp = LineReader(dp, decode=True, return_path=False) + dp = hint_sharding(dp) + dp = hint_shuffling(dp) + return Mapper(dp, self._prepare_sample)