From d675ff45dfb8e73eb71441759f5207eb6fa58306 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 5 Dec 2022 13:57:25 +0100
Subject: [PATCH] move Datapoint out of public namespace

---
 test/test_prototype_transforms_functional.py  |  2 +-
 torchvision/prototype/datapoints/__init__.py  |  2 +-
 .../prototype/datasets/_builtin/caltech.py    |  3 +-
 .../prototype/datasets/_builtin/celeba.py     |  3 +-
 .../prototype/datasets/_builtin/coco.py       |  3 +-
 .../prototype/datasets/_builtin/cub200.py     |  3 +-
 .../prototype/datasets/_builtin/sbd.py        |  2 +-
 .../prototype/transforms/functional/_color.py | 66 ++++++++++++-------
 .../transforms/functional/_geometry.py        | 66 ++++++++++++-------
 .../prototype/transforms/functional/_meta.py  |  4 +-
 .../prototype/transforms/functional/_misc.py  |  6 +-
 torchvision/prototype/transforms/utils.py     |  2 +-
 12 files changed, 107 insertions(+), 55 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 8ad55722378..7cd84fbcd61 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -409,7 +409,7 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature
     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
     def test_dispatcher_feature_signatures_consistency(self, info):
         try:
-            feature_method = getattr(datapoints.Datapoint, info.id)
+            feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
         except AttributeError:
             pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
 
diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py
index 6b2ff438b00..92f345e20bd 100644
--- a/torchvision/prototype/datapoints/__init__.py
+++ b/torchvision/prototype/datapoints/__init__.py
@@ -1,5 +1,5 @@
 from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._datapoint import Datapoint, FillType, FillTypeJIT, InputType, InputTypeJIT
+from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
 from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
 from ._label import Label, OneHotLabel
 from ._mask import Mask
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index d675bfdb96f..55a77c1a920 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -4,7 +4,8 @@
 import numpy as np
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label
+from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py
index 1fe977f6cf6..3082a28c30c 100644
--- a/torchvision/prototype/datasets/_builtin/celeba.py
+++ b/torchvision/prototype/datasets/_builtin/celeba.py
@@ -3,7 +3,8 @@
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label
+from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index 6744e1dfed0..fa68bf4dc6f 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -14,7 +14,8 @@
     Mapper,
     UnBatcher,
 )
-from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label, Mask
+from torchvision.prototype.datapoints import BoundingBox, Label, Mask
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py
index 615f4b1f528..ea192baf650 100644
--- a/torchvision/prototype/datasets/_builtin/cub200.py
+++ b/torchvision/prototype/datasets/_builtin/cub200.py
@@ -14,7 +14,8 @@
     Mapper,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.prototype.datapoints import BoundingBox, Datapoint, Label
+from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index f8b88a214c3..c9f054b2c9e 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -4,7 +4,7 @@
 import numpy as np
 
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints import Datapoint
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index a9c4011f1ae..618968cbb48 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -38,9 +38,11 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
 
 
 def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_brightness(brightness_factor=brightness_factor)
     elif isinstance(inpt, PIL.Image.Image):
         return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
@@ -77,9 +79,11 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
 
 
 def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_saturation(saturation_factor=saturation_factor)
     elif isinstance(inpt, PIL.Image.Image):
         return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
@@ -116,9 +120,11 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
 
 
 def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_contrast(contrast_factor=contrast_factor)
     elif isinstance(inpt, PIL.Image.Image):
         return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
@@ -189,9 +195,11 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
 
 
 def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
     elif isinstance(inpt, PIL.Image.Image):
         return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
@@ -301,9 +309,11 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
 
 
 def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_hue(hue_factor=hue_factor)
     elif isinstance(inpt, PIL.Image.Image):
         return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
@@ -341,9 +351,11 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
 
 
 def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_gamma(gamma=gamma, gain=gain)
     elif isinstance(inpt, PIL.Image.Image):
         return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
@@ -375,9 +387,11 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
 
 
 def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return posterize_image_tensor(inpt, bits=bits)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.posterize(bits=bits)
     elif isinstance(inpt, PIL.Image.Image):
         return posterize_image_pil(inpt, bits=bits)
@@ -403,9 +417,11 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
 
 
 def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return solarize_image_tensor(inpt, threshold=threshold)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.solarize(threshold=threshold)
     elif isinstance(inpt, PIL.Image.Image):
         return solarize_image_pil(inpt, threshold=threshold)
@@ -453,9 +469,11 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return autocontrast_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.autocontrast()
     elif isinstance(inpt, PIL.Image.Image):
         return autocontrast_image_pil(inpt)
@@ -543,9 +561,11 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return equalize_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.equalize()
     elif isinstance(inpt, PIL.Image.Image):
         return equalize_image_pil(inpt)
@@ -574,9 +594,11 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return invert_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.invert()
     elif isinstance(inpt, PIL.Image.Image):
         return invert_image_pil(inpt)
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 0c9d26ebca5..cef68d66ee9 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -55,9 +55,11 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return horizontal_flip_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.horizontal_flip()
     elif isinstance(inpt, PIL.Image.Image):
         return horizontal_flip_image_pil(inpt)
@@ -101,9 +103,11 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
 
 
 def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return vertical_flip_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.vertical_flip()
     elif isinstance(inpt, PIL.Image.Image):
         return vertical_flip_image_pil(inpt)
@@ -227,9 +231,11 @@ def resize(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
     elif isinstance(inpt, PIL.Image.Image):
         if antialias is not None and not antialias:
@@ -725,7 +731,9 @@ def affine(
     center: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
     # TODO: consider deprecating integers from angle and shear on the future
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return affine_image_tensor(
             inpt,
             angle,
@@ -736,7 +744,7 @@ def affine(
         fill=fill,
         center=center,
     )
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.affine(
             angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
         )
@@ -905,9 +913,11 @@ def rotate(
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, PIL.Image.Image):
         return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
@@ -1110,10 +1120,12 @@ def pad(
     fill: datapoints.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
 
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
     elif isinstance(inpt, PIL.Image.Image):
         return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
@@ -1185,9 +1197,11 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
 
 
 def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return crop_image_tensor(inpt, top, left, height, width)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.crop(top, left, height, width)
     elif isinstance(inpt, PIL.Image.Image):
         return crop_image_pil(inpt, top, left, height, width)
@@ -1438,11 +1452,13 @@ def perspective(
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return perspective_image_tensor(
             inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
         )
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.perspective(
             startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
         )
@@ -1596,9 +1612,11 @@ def elastic(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, PIL.Image.Image):
         return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
@@ -1706,9 +1724,11 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
 
 
 def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return center_crop_image_tensor(inpt, output_size)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.center_crop(output_size)
     elif isinstance(inpt, PIL.Image.Image):
         return center_crop_image_pil(inpt, output_size)
@@ -1797,11 +1817,13 @@ def resized_crop(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: Optional[bool] = None,
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return resized_crop_image_tensor(
             inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
         )
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
     elif isinstance(inpt, PIL.Image.Image):
         return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index f3d0008fba5..a6b9c773891 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -109,7 +109,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
 
 
 def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return get_spatial_size_image_tensor(inpt)
     elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
         return list(inpt.spatial_size)
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 59f1adbc1ab..7799187373f 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -168,9 +168,11 @@ def gaussian_blur_video(
 def gaussian_blur(
     inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> datapoints.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, datapoints.Datapoint):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, PIL.Image.Image):
         return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py
index 0698e88fa57..9ab2ed2602b 100644
--- a/torchvision/prototype/transforms/utils.py
+++ b/torchvision/prototype/transforms/utils.py
@@ -7,7 +7,7 @@
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import datapoints
-from torchvision.prototype.datapoints import Datapoint
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size
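Note: every dispatcher touched above follows the same three-way dispatch, so the
change is mechanical: `datapoints.Datapoint` becomes
`datapoints._datapoint.Datapoint`, and the now-overlong condition is wrapped
across lines. The snippet below is a usage sketch, not part of the patch: it
shows from the caller's side what the dispatchers keep doing after the move,
assuming the prototype API at the time of this patch (`datapoints.Image`,
`adjust_brightness`); the tensor shapes and factor are made up for illustration.

    import torch

    from torchvision.prototype import datapoints
    # Datapoint is no longer re-exported, so the private module path is needed.
    from torchvision.prototype.datapoints._datapoint import Datapoint
    from torchvision.prototype.transforms.functional import adjust_brightness

    plain = torch.rand(3, 32, 32)                    # plain tensor -> tensor kernel
    image = datapoints.Image(torch.rand(3, 32, 32))  # Datapoint subclass -> method dispatch

    out_plain = adjust_brightness(plain, brightness_factor=1.5)
    out_image = adjust_brightness(image, brightness_factor=1.5)

    assert type(out_plain) is torch.Tensor      # tensor path returns a plain tensor
    assert type(out_image) is datapoints.Image  # Datapoint path re-wraps the result

    # The isinstance check mirrors the dispatchers' internal test; after this
    # patch the public namespace no longer carries the Datapoint name.
    assert isinstance(image, Datapoint)
    assert not hasattr(datapoints, "Datapoint")

Dispatch behavior is unchanged; only the location of the `Datapoint` name moves,
which keeps the base class out of the public API surface while the prototype
transforms continue to use it internally.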