diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 3876beea5c4..2c8540f093c 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -330,6 +330,10 @@ def rotate_segmentation_mask(): and callable(kernel) and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"}) and "pil" not in name + and name + not in { + "to_image_tensor", + } ], ) def test_scriptable(kernel): diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 94a642cb691..a4e845d75c3 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -22,4 +22,4 @@ from ._misc import Identity, Normalize, ToDtype, Lambda from ._type_conversion import DecodeImage, LabelToOneHot -from ._legacy import Grayscale, RandomGrayscale # usort: skip +from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip diff --git a/torchvision/prototype/transforms/_legacy.py b/torchvision/prototype/transforms/_deprecated.py similarity index 60% rename from torchvision/prototype/transforms/_legacy.py rename to torchvision/prototype/transforms/_deprecated.py index 2a8cce55886..b9b712ebcae 100644 --- a/torchvision/prototype/transforms/_legacy.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -1,14 +1,63 @@ -from __future__ import annotations - import warnings -from typing import Any, Dict +from typing import Any, Dict, Optional +import numpy as np +import PIL.Image +from torchvision.prototype import features from torchvision.prototype.features import ColorSpace from torchvision.prototype.transforms import Transform +from torchvision.transforms import functional as _F from typing_extensions import Literal from ._meta import ConvertImageColorSpace from ._transform import _RandomApplyTransform +from ._utils import is_simple_tensor + + +class ToTensor(Transform): + def __init__(self) -> None: + warnings.warn( + "The transform `ToTensor()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.ToImageTensor()`." + ) + super().__init__() + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (PIL.Image.Image, np.ndarray)): + return _F.to_tensor(input) + else: + return input + + +class PILToTensor(Transform): + def __init__(self) -> None: + warnings.warn( + "The transform `PILToTensor()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.ToImageTensor()`." + ) + super().__init__() + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, PIL.Image.Image): + return _F.pil_to_tensor(input) + else: + return input + + +class ToPILImage(Transform): + def __init__(self, mode: Optional[str] = None) -> None: + warnings.warn( + "The transform `ToPILImage()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.ToImagePIL()`." + ) + super().__init__() + self.mode = mode + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)): + return _F.to_pil_image(input, mode=self.mode) + else: + return input class Grayscale(Transform): diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index f2dc426897b..09c071a27e0 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,8 +1,12 @@ from typing import Any, Dict +import numpy as np +import PIL.Image from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F +from ._utils import is_simple_tensor + class DecodeImage(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: @@ -33,3 +37,28 @@ def extra_repr(self) -> str: return "" return f"num_categories={self.num_categories}" + + +class ToImageTensor(Transform): + def __init__(self, *, copy: bool = False) -> None: + super().__init__() + self.copy = copy + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input): + output = F.to_image_tensor(input, copy=self.copy) + return features.Image(output) + else: + return input + + +class ToImagePIL(Transform): + def __init__(self, *, copy: bool = False) -> None: + super().__init__() + self.copy = copy + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input): + return F.to_image_pil(input, copy=self.copy) + else: + return input diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index e8f25342a18..64d47958b96 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -74,4 +74,10 @@ ten_crop_image_pil, ) from ._misc import normalize_image_tensor, gaussian_blur_image_tensor -from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot +from ._type_conversion import ( + decode_image_with_pil, + decode_video_with_av, + label_to_one_hot, + to_image_tensor, + to_image_pil, +) diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index 06b2daaf6f1..37f8f9b70a3 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Dict, Any, Tuple +from typing import Dict, Any, Tuple, Union import numpy as np import PIL.Image @@ -7,6 +7,7 @@ from torch.nn.functional import one_hot from torchvision.io.video import read_video from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer +from torchvision.transforms import functional as _F def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor: @@ -23,3 +24,23 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor: return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return] + + +def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor: + if isinstance(image, torch.Tensor): + if copy: + return image.clone() + else: + return image + + return _F.to_tensor(image) + + +def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image: + if isinstance(image, PIL.Image.Image): + if copy: + return image.copy() + else: + return image + + return _F.to_pil_image(to_image_tensor(image, copy=False)) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 230ad67f683..e964b10e18e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -120,7 +120,7 @@ def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} -def to_tensor(pic): +def to_tensor(pic) -> Tensor: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This function does not support torchscript.