From 31e503f13f54bdf631dfd9038b3019d60ba3a9f4 Mon Sep 17 00:00:00 2001 From: Yassine Alouini Date: Mon, 4 Apr 2022 11:46:39 +0200 Subject: [PATCH 1/3] Food101 new dataset api (#5584) * [FEAT] Start implementing Food101 using the new datasets API. WIP. * [FEAT] Generate Food101 categories and start the test mock. * [FEAT] food101 dataset code seems to work now. * [TEST] food101 mock update. * [FIX] Some fixes thanks to running food101 tests. * [FIX] Fix mypy checks for the food101 file. * [FIX] Remove unused numpy. * [FIX] Some changes thanks to code review. * [ENH] More idomatic dataset code thanks to code review. * [FIX] Remove unused cast. * [ENH] Set decompress and extract to True for some performance gains. * [FEAT] Use the preprocess=decompress keyword. * [ENH] Use the train and test.txt file instead of the .json variants and simplify code + update mock data. * [ENH] Better food101 mock data generation. * [FIX] Remove a useless print. Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 38 +++++++ .../prototype/datasets/_builtin/__init__.py | 1 + .../datasets/_builtin/food101.categories | 101 ++++++++++++++++++ .../prototype/datasets/_builtin/food101.py | 91 ++++++++++++++++ 4 files changed, 231 insertions(+) create mode 100644 torchvision/prototype/datasets/_builtin/food101.categories create mode 100644 torchvision/prototype/datasets/_builtin/food101.py diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 768177b1c28..c5608377d97 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -911,6 +911,44 @@ def country211(info, root, config): return num_examples * len(classes) +@register_mock +def food101(info, root, config): + data_folder = root / "food-101" + + num_images_per_class = 3 + image_folder = data_folder / "images" + categories = ["apple_pie", "baby_back_ribs", "waffles"] + image_ids = [] + for category in categories: + image_files = create_image_folder( + image_folder, + category, + file_name_fn=lambda idx: f"{idx:04d}.jpg", + num_examples=num_images_per_class, + ) + image_ids.extend(path.relative_to(path.parents[1]).with_suffix("").as_posix() for path in image_files) + + meta_folder = data_folder / "meta" + meta_folder.mkdir() + + with open(meta_folder / "classes.txt", "w") as file: + for category in categories: + file.write(f"{category}\n") + + splits = ["train", "test"] + num_samples_map = {} + for offset, split in enumerate(splits): + image_ids_in_split = image_ids[offset :: len(splits)] + num_samples_map[split] = len(image_ids_in_split) + with open(meta_folder / f"{split}.txt", "w") as file: + for image_id in image_ids_in_split: + file.write(f"{image_id}\n") + + make_tar(root, f"{data_folder.name}.tar.gz", compression="gz") + + return num_samples_map[config.split] + + @register_mock def dtd(info, root, config): data_folder = root / "dtd" diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 1a8dc0907a4..b2beddc7f2b 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -8,6 +8,7 @@ from .dtd import DTD from .eurosat import EuroSAT from .fer2013 import FER2013 +from .food101 import Food101 from .gtsrb import GTSRB from .imagenet import ImageNet from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST diff --git a/torchvision/prototype/datasets/_builtin/food101.categories b/torchvision/prototype/datasets/_builtin/food101.categories new file mode 100644 index 00000000000..59f252ddff4 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/food101.categories @@ -0,0 +1,101 @@ +apple_pie +baby_back_ribs +baklava +beef_carpaccio +beef_tartare +beet_salad +beignets +bibimbap +bread_pudding +breakfast_burrito +bruschetta +caesar_salad +cannoli +caprese_salad +carrot_cake +ceviche +cheesecake +cheese_plate +chicken_curry +chicken_quesadilla +chicken_wings +chocolate_cake +chocolate_mousse +churros +clam_chowder +club_sandwich +crab_cakes +creme_brulee +croque_madame +cup_cakes +deviled_eggs +donuts +dumplings +edamame +eggs_benedict +escargots +falafel +filet_mignon +fish_and_chips +foie_gras +french_fries +french_onion_soup +french_toast +fried_calamari +fried_rice +frozen_yogurt +garlic_bread +gnocchi +greek_salad +grilled_cheese_sandwich +grilled_salmon +guacamole +gyoza +hamburger +hot_and_sour_soup +hot_dog +huevos_rancheros +hummus +ice_cream +lasagna +lobster_bisque +lobster_roll_sandwich +macaroni_and_cheese +macarons +miso_soup +mussels +nachos +omelette +onion_rings +oysters +pad_thai +paella +pancakes +panna_cotta +peking_duck +pho +pizza +pork_chop +poutine +prime_rib +pulled_pork_sandwich +ramen +ravioli +red_velvet_cake +risotto +samosa +sashimi +scallops +seaweed_salad +shrimp_and_grits +spaghetti_bolognese +spaghetti_carbonara +spring_rolls +steak +strawberry_shortcake +sushi +tacos +takoyaki +tiramisu +tuna_tartare +waffles diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py new file mode 100644 index 00000000000..cb720f137d9 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -0,0 +1,91 @@ +from pathlib import Path +from typing import Any, Tuple, List, Dict, Optional, BinaryIO + +from torchdata.datapipes.iter import ( + IterDataPipe, + Filter, + Mapper, + LineReader, + Demultiplexer, + IterKeyZipper, +) +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + hint_shuffling, + hint_sharding, + path_comparator, + getitem, + INFINITE_BUFFER_SIZE, +) +from torchvision.prototype.features import Label, EncodedImage + + +class Food101(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "food101", + homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", + valid_options=dict(split=("train", "test")), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz", + sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4", + preprocess="decompress", + ) + ] + + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: + path = Path(data[0]) + if path.parents[1].name == "images": + return 0 + elif path.parents[0].name == "meta": + return 1 + else: + return None + + def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]: + id, (path, buffer) = data + return dict( + label=Label.from_category(id.split("/", 1)[0], categories=self.categories), + path=path, + image=EncodedImage.from_file(buffer), + ) + + def _image_key(self, data: Tuple[str, Any]) -> str: + path = Path(data[0]) + return path.relative_to(path.parents[1]).with_suffix("").as_posix() + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] + images_dp, split_dp = Demultiplexer( + archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE + ) + split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = LineReader(split_dp, decode=True, return_path=False) + split_dp = hint_sharding(split_dp) + split_dp = hint_shuffling(split_dp) + + dp = IterKeyZipper( + split_dp, + images_dp, + key_fn=getitem(), + ref_key_fn=self._image_key, + buffer_size=INFINITE_BUFFER_SIZE, + ) + + return Mapper(dp, self._prepare_sample) + + def _generate_categories(self, root: Path) -> List[str]: + resources = self.resources(self.default_config) + dp = resources[0].load(root) + dp = Filter(dp, path_comparator("name", "classes.txt")) + dp = LineReader(dp, decode=True, return_path=False) + return list(dp) From 890450a457895829d75befbf1ff627cb14a95763 Mon Sep 17 00:00:00 2001 From: KyleCZH <70175284+KyleCZH@users.noreply.github.com> Date: Mon, 4 Apr 2022 04:43:12 -0700 Subject: [PATCH 2/3] [ROCm] Update to rocm5.0 wheels (#5660) Signed-off-by: Kyle Chen Co-authored-by: Vasilis Vryniotis --- .circleci/config.yml | 136 ++++++++++++++++++++-------------------- .circleci/regenerate.py | 2 +- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 2a7c679e021..df93de4692e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1141,16 +1141,16 @@ workflows: name: binary_linux_wheel_py3.7_cu115 python_version: '3.7' wheel_docker_image: pytorch/manylinux-cuda115 - - binary_linux_wheel: - cu_version: rocm4.3.1 - name: binary_linux_wheel_py3.7_rocm4.3.1 - python_version: '3.7' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 - binary_linux_wheel: cu_version: rocm4.5.2 name: binary_linux_wheel_py3.7_rocm4.5.2 python_version: '3.7' wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + - binary_linux_wheel: + cu_version: rocm5.0 + name: binary_linux_wheel_py3.7_rocm5.0 + python_version: '3.7' + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -1175,16 +1175,16 @@ workflows: name: binary_linux_wheel_py3.8_cu115 python_version: '3.8' wheel_docker_image: pytorch/manylinux-cuda115 - - binary_linux_wheel: - cu_version: rocm4.3.1 - name: binary_linux_wheel_py3.8_rocm4.3.1 - python_version: '3.8' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 - binary_linux_wheel: cu_version: rocm4.5.2 name: binary_linux_wheel_py3.8_rocm4.5.2 python_version: '3.8' wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + - binary_linux_wheel: + cu_version: rocm5.0 + name: binary_linux_wheel_py3.8_rocm5.0 + python_version: '3.8' + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -1209,16 +1209,16 @@ workflows: name: binary_linux_wheel_py3.9_cu115 python_version: '3.9' wheel_docker_image: pytorch/manylinux-cuda115 - - binary_linux_wheel: - cu_version: rocm4.3.1 - name: binary_linux_wheel_py3.9_rocm4.3.1 - python_version: '3.9' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 - binary_linux_wheel: cu_version: rocm4.5.2 name: binary_linux_wheel_py3.9_rocm4.5.2 python_version: '3.9' wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + - binary_linux_wheel: + cu_version: rocm5.0 + name: binary_linux_wheel_py3.9_rocm5.0 + python_version: '3.9' + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -1243,16 +1243,16 @@ workflows: name: binary_linux_wheel_py3.10_cu115 python_version: '3.10' wheel_docker_image: pytorch/manylinux-cuda115 - - binary_linux_wheel: - cu_version: rocm4.3.1 - name: binary_linux_wheel_py3.10_rocm4.3.1 - python_version: '3.10' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 - binary_linux_wheel: cu_version: rocm4.5.2 name: binary_linux_wheel_py3.10_rocm4.5.2 python_version: '3.10' wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + - binary_linux_wheel: + cu_version: rocm5.0 + name: binary_linux_wheel_py3.10_rocm5.0 + python_version: '3.10' + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_macos_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -1898,15 +1898,15 @@ workflows: - nightly_binary_linux_wheel_py3.7_cu115 subfolder: cu115/ - binary_linux_wheel: - cu_version: rocm4.3.1 + cu_version: rocm4.5.2 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.7_rocm4.3.1 + name: nightly_binary_linux_wheel_py3.7_rocm4.5.2 python_version: '3.7' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 + wheel_docker_image: pytorch/manylinux-rocm:4.5.2 - binary_wheel_upload: context: org-member filters: @@ -1914,20 +1914,20 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.7_rocm4.3.1_upload + name: nightly_binary_linux_wheel_py3.7_rocm4.5.2_upload requires: - - nightly_binary_linux_wheel_py3.7_rocm4.3.1 - subfolder: rocm4.3.1/ + - nightly_binary_linux_wheel_py3.7_rocm4.5.2 + subfolder: rocm4.5.2/ - binary_linux_wheel: - cu_version: rocm4.5.2 + cu_version: rocm5.0 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.7_rocm4.5.2 + name: nightly_binary_linux_wheel_py3.7_rocm5.0 python_version: '3.7' - wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_wheel_upload: context: org-member filters: @@ -1935,10 +1935,10 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.7_rocm4.5.2_upload + name: nightly_binary_linux_wheel_py3.7_rocm5.0_upload requires: - - nightly_binary_linux_wheel_py3.7_rocm4.5.2 - subfolder: rocm4.5.2/ + - nightly_binary_linux_wheel_py3.7_rocm5.0 + subfolder: rocm5.0/ - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -2028,15 +2028,15 @@ workflows: - nightly_binary_linux_wheel_py3.8_cu115 subfolder: cu115/ - binary_linux_wheel: - cu_version: rocm4.3.1 + cu_version: rocm4.5.2 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.8_rocm4.3.1 + name: nightly_binary_linux_wheel_py3.8_rocm4.5.2 python_version: '3.8' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 + wheel_docker_image: pytorch/manylinux-rocm:4.5.2 - binary_wheel_upload: context: org-member filters: @@ -2044,20 +2044,20 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.8_rocm4.3.1_upload + name: nightly_binary_linux_wheel_py3.8_rocm4.5.2_upload requires: - - nightly_binary_linux_wheel_py3.8_rocm4.3.1 - subfolder: rocm4.3.1/ + - nightly_binary_linux_wheel_py3.8_rocm4.5.2 + subfolder: rocm4.5.2/ - binary_linux_wheel: - cu_version: rocm4.5.2 + cu_version: rocm5.0 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.8_rocm4.5.2 + name: nightly_binary_linux_wheel_py3.8_rocm5.0 python_version: '3.8' - wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_wheel_upload: context: org-member filters: @@ -2065,10 +2065,10 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.8_rocm4.5.2_upload + name: nightly_binary_linux_wheel_py3.8_rocm5.0_upload requires: - - nightly_binary_linux_wheel_py3.8_rocm4.5.2 - subfolder: rocm4.5.2/ + - nightly_binary_linux_wheel_py3.8_rocm5.0 + subfolder: rocm5.0/ - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -2158,15 +2158,15 @@ workflows: - nightly_binary_linux_wheel_py3.9_cu115 subfolder: cu115/ - binary_linux_wheel: - cu_version: rocm4.3.1 + cu_version: rocm4.5.2 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.9_rocm4.3.1 + name: nightly_binary_linux_wheel_py3.9_rocm4.5.2 python_version: '3.9' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 + wheel_docker_image: pytorch/manylinux-rocm:4.5.2 - binary_wheel_upload: context: org-member filters: @@ -2174,20 +2174,20 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.9_rocm4.3.1_upload + name: nightly_binary_linux_wheel_py3.9_rocm4.5.2_upload requires: - - nightly_binary_linux_wheel_py3.9_rocm4.3.1 - subfolder: rocm4.3.1/ + - nightly_binary_linux_wheel_py3.9_rocm4.5.2 + subfolder: rocm4.5.2/ - binary_linux_wheel: - cu_version: rocm4.5.2 + cu_version: rocm5.0 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.9_rocm4.5.2 + name: nightly_binary_linux_wheel_py3.9_rocm5.0 python_version: '3.9' - wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_wheel_upload: context: org-member filters: @@ -2195,10 +2195,10 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.9_rocm4.5.2_upload + name: nightly_binary_linux_wheel_py3.9_rocm5.0_upload requires: - - nightly_binary_linux_wheel_py3.9_rocm4.5.2 - subfolder: rocm4.5.2/ + - nightly_binary_linux_wheel_py3.9_rocm5.0 + subfolder: rocm5.0/ - binary_linux_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu @@ -2288,15 +2288,15 @@ workflows: - nightly_binary_linux_wheel_py3.10_cu115 subfolder: cu115/ - binary_linux_wheel: - cu_version: rocm4.3.1 + cu_version: rocm4.5.2 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.10_rocm4.3.1 + name: nightly_binary_linux_wheel_py3.10_rocm4.5.2 python_version: '3.10' - wheel_docker_image: pytorch/manylinux-rocm:4.3.1 + wheel_docker_image: pytorch/manylinux-rocm:4.5.2 - binary_wheel_upload: context: org-member filters: @@ -2304,20 +2304,20 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.10_rocm4.3.1_upload + name: nightly_binary_linux_wheel_py3.10_rocm4.5.2_upload requires: - - nightly_binary_linux_wheel_py3.10_rocm4.3.1 - subfolder: rocm4.3.1/ + - nightly_binary_linux_wheel_py3.10_rocm4.5.2 + subfolder: rocm4.5.2/ - binary_linux_wheel: - cu_version: rocm4.5.2 + cu_version: rocm5.0 filters: branches: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.10_rocm4.5.2 + name: nightly_binary_linux_wheel_py3.10_rocm5.0 python_version: '3.10' - wheel_docker_image: pytorch/manylinux-rocm:4.5.2 + wheel_docker_image: pytorch/manylinux-rocm:5.0 - binary_wheel_upload: context: org-member filters: @@ -2325,10 +2325,10 @@ workflows: only: nightly tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - name: nightly_binary_linux_wheel_py3.10_rocm4.5.2_upload + name: nightly_binary_linux_wheel_py3.10_rocm5.0_upload requires: - - nightly_binary_linux_wheel_py3.10_rocm4.5.2 - subfolder: rocm4.5.2/ + - nightly_binary_linux_wheel_py3.10_rocm5.0 + subfolder: rocm5.0/ - binary_macos_wheel: conda_docker_image: pytorch/conda-builder:cpu cu_version: cpu diff --git a/.circleci/regenerate.py b/.circleci/regenerate.py index 04267907ad9..52f632fcb2a 100755 --- a/.circleci/regenerate.py +++ b/.circleci/regenerate.py @@ -32,7 +32,7 @@ def build_workflows(prefix="", filter_branch=None, upload=False, indentation=6, for os_type in ["linux", "macos", "win"]: python_versions = PYTHON_VERSIONS cu_versions_dict = { - "linux": ["cpu", "cu102", "cu113", "cu115", "rocm4.3.1", "rocm4.5.2"], + "linux": ["cpu", "cu102", "cu113", "cu115", "rocm4.5.2", "rocm5.0"], "win": ["cpu", "cu113", "cu115"], "macos": ["cpu"], } From 3130b457934124ffc7e9bdb6b2d86efa9a8c71cf Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 4 Apr 2022 14:15:10 +0200 Subject: [PATCH 3/3] [proto] Added functional `rotate_segmentation_mask` op (#5692) * Added functional affine_bounding_box op with tests * Updated comments and added another test case * Update _geometry.py * Added affine_segmentation_mask with tests * Fixed device mismatch issue Added a cude/cpu test Reduced the number of test samples * Added test_correctness_affine_segmentation_mask_on_fixed_input * Updates according to the review * Replaced [None, ...] by [None, :] * Adressed review comments * Fixed formatting and more updates according to the review * Fixed bad merge * WIP * Fixed tests * Updated warning message --- test/test_prototype_transforms_functional.py | 145 ++++++++++++++++-- .../transforms/functional/__init__.py | 1 + .../transforms/functional/_geometry.py | 23 ++- 3 files changed, 152 insertions(+), 17 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6f10945feaf..3876beea5c4 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -266,7 +266,7 @@ def affine_bounding_box(): @register_kernel_info_from_sample_inputs_fn def affine_segmentation_mask(): - for image, angle, translate, scale, shear in itertools.product( + for mask, angle, translate, scale, shear in itertools.product( make_segmentation_masks(extra_dims=((), (4,))), [-87, 15, 90], # angle [5, -5], # translate @@ -274,7 +274,7 @@ def affine_segmentation_mask(): [0, 12], # shear ): yield SampleInput( - image, + mask, angle=angle, translate=(translate, translate), scale=scale, @@ -285,8 +285,12 @@ def affine_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn def rotate_bounding_box(): for bounding_box, angle, expand, center in itertools.product( - make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center + make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] ): + if center is not None and expand: + # Skip warning: The provided center argument is ignored if expand is True + continue + yield SampleInput( bounding_box, format=bounding_box.format, @@ -297,6 +301,26 @@ def rotate_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn +def rotate_segmentation_mask(): + for mask, angle, expand, center in itertools.product( + make_segmentation_masks(extra_dims=((), (4,))), + [-87, 15, 90], # angle + [True, False], # expand + [None, [12, 23]], # center + ): + if center is not None and expand: + # Skip warning: The provided center argument is ignored if expand is True + continue + + yield SampleInput( + mask, + angle=angle, + expand=expand, + center=center, + ) + + @pytest.mark.parametrize( "kernel", [ @@ -411,8 +435,9 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): center=center, ) - if center is None: - center = [s // 2 for s in bboxes_image_size[::-1]] + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in bboxes_image_size[::-1]] if bboxes.ndim < 2: bboxes = [bboxes] @@ -421,7 +446,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): for bbox in bboxes: bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) expected_bboxes.append( - _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) + _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_) ) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) @@ -510,8 +535,10 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): shear=(shear, shear), center=center, ) - if center is None: - center = [s // 2 for s in mask.shape[-2:][::-1]] + + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in mask.shape[-2:][::-1]] if mask.ndim < 4: masks = [mask] @@ -520,7 +547,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): expected_masks = [] for mask in masks: - expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center) + expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_) expected_masks.append(expected_mask) if len(expected_masks) > 1: expected_masks = torch.stack(expected_masks) @@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device): @pytest.mark.parametrize("angle", range(-90, 90, 56)) -@pytest.mark.parametrize("expand", [True, False]) -@pytest.mark.parametrize("center", [None, (12, 14)]) +@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) def test_correctness_rotate_bounding_box(angle, expand, center): def _compute_expected_bbox(bbox, angle_, expand_, center_): affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) @@ -620,8 +646,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): center=center, ) - if center is None: - center = [s // 2 for s in bboxes_image_size[::-1]] + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in bboxes_image_size[::-1]] if bboxes.ndim < 2: bboxes = [bboxes] @@ -629,7 +656,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): expected_bboxes = [] for bbox in bboxes: bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center)) + expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_)) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) else: @@ -638,7 +665,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): @pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress +@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): # Check transformation against known expected output image_size = (64, 64) @@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): ) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) + + +@pytest.mark.parametrize("angle", range(-90, 90, 37)) +@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) +def test_correctness_rotate_segmentation_mask(angle, expand, center): + def _compute_expected_mask(mask, angle_, expand_, center_): + assert mask.ndim == 3 and mask.shape[0] == 1 + image_size = mask.shape[-2:] + affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) + inv_affine_matrix = np.linalg.inv(affine_matrix) + + if expand_: + # Pillow implementation on how to perform expand: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069 + height, width = image_size + points = np.array( + [ + [0.0, 0.0, 1.0], + [0.0, 1.0 * height, 1.0], + [1.0 * width, 1.0 * height, 1.0], + [1.0 * width, 0.0, 1.0], + ] + ) + new_points = points @ inv_affine_matrix.T + min_vals = np.min(new_points, axis=0)[:2] + max_vals = np.max(new_points, axis=0)[:2] + cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4) + cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4) + new_width, new_height = (cmax - cmin).astype("int32").tolist() + tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T + + inv_affine_matrix[:2, 2] = tr[:2] + image_size = [new_height, new_width] + + inv_affine_matrix = inv_affine_matrix[:2, :] + expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype) + + for out_y in range(expected_mask.shape[1]): + for out_x in range(expected_mask.shape[2]): + output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0]) + input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32) + in_x, in_y = input_pt[:2] + if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]: + expected_mask[0, out_y, out_x] = mask[0, in_y, in_x] + return expected_mask.to(mask.device) + + for mask in make_segmentation_masks(extra_dims=((), (4,))): + output_mask = F.rotate_segmentation_mask( + mask, + angle=angle, + expand=expand, + center=center, + ) + + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in mask.shape[-2:][::-1]] + + if mask.ndim < 4: + masks = [mask] + else: + masks = [m for m in mask] + + expected_masks = [] + for mask in masks: + expected_mask = _compute_expected_mask(mask, -angle, expand, center_) + expected_masks.append(expected_mask) + if len(expected_masks) > 1: + expected_masks = torch.stack(expected_masks) + else: + expected_masks = expected_masks[0] + torch.testing.assert_close(output_mask, expected_masks) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_rotate_segmentation_mask_on_fixed_input(device): + # Check transformation against known expected output and CPU/CUDA devices + + # Create a fixed input segmentation mask with 2 square masks + # in top-left, bottom-left corners + mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) + mask[0, 2:10, 2:10] = 1 + mask[0, 32 - 9 : 32 - 3, 3:9] = 2 + + # Rotate 90 degrees + expected_mask = torch.rot90(mask, k=1, dims=(-2, -1)) + out_mask = F.rotate_segmentation_mask(mask, 90, expand=False) + torch.testing.assert_close(out_mask, expected_mask) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 51bf73a18f7..e8f25342a18 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -56,6 +56,7 @@ rotate_bounding_box, rotate_image_tensor, rotate_image_pil, + rotate_segmentation_mask, pad_image_tensor, pad_image_pil, pad_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 71882f06270..7629766c0e2 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -324,7 +324,7 @@ def rotate_image_tensor( center_f = [0.0, 0.0] if center is not None: if expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") else: _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. @@ -345,7 +345,7 @@ def rotate_image_pil( center: Optional[List[float]] = None, ) -> PIL.Image.Image: if center is not None and expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") center = None return _FP.rotate( @@ -361,6 +361,10 @@ def rotate_bounding_box( expand: bool = False, center: Optional[List[float]] = None, ) -> torch.Tensor: + if center is not None and expand: + warnings.warn("The provided center argument has no effect on the result if expand is True") + center = None + original_shape = bounding_box.shape bounding_box = convert_bounding_box_format( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY @@ -373,6 +377,21 @@ def rotate_bounding_box( ).view(original_shape) +def rotate_segmentation_mask( + img: torch.Tensor, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return rotate_image_tensor( + img, + angle=angle, + expand=expand, + interpolation=InterpolationMode.NEAREST, + center=center, + ) + + pad_image_tensor = _FT.pad pad_image_pil = _FP.pad