diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 8648a09ad94..33c390f9fe7 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -17,7 +17,7 @@ from torch.nn.functional import one_hot from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import datapoints -from torchvision.transforms.functional_tensor import _max_value as get_max_value +from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor __all__ = [ diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index a0f7da5e262..5e8be39aef6 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -29,7 +29,7 @@ ) from torch.utils._pytree import tree_map from torchvision import datapoints -from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding +from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding __all__ = ["KernelInfo", "KERNEL_INFOS"] diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index e7c845dc8e6..0e1cc648a19 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -11,9 +11,9 @@ import pytest import torch import torchvision.transforms as T +import torchvision.transforms._functional_pil as F_pil +import torchvision.transforms._functional_tensor as F_t import torchvision.transforms.functional as F -import torchvision.transforms.functional_pil as F_pil -import torchvision.transforms.functional_tensor as F_t from common_utils import ( _assert_approx_equal_tensor_to_pil, _assert_equal_tensor_to_pil, diff --git a/test/test_transforms.py b/test/test_transforms.py index 57e61bbad70..cec69c0cbe9 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -9,8 +9,8 @@ import pytest import torch import torchvision.transforms as transforms +import torchvision.transforms._functional_tensor as F_t import torchvision.transforms.functional as F -import torchvision.transforms.functional_tensor as F_t from PIL import Image from torch._utils_internal import get_file_path_2 diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py new file mode 100644 index 00000000000..120998d0072 --- /dev/null +++ b/torchvision/transforms/_functional_pil.py @@ -0,0 +1,391 @@ +import numbers +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from PIL import Image, ImageEnhance, ImageOps + +try: + import accimage +except ImportError: + accimage = None + + +@torch.jit.unused +def _is_pil_image(img: Any) -> bool: + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +@torch.jit.unused +def get_dimensions(img: Any) -> List[int]: + if _is_pil_image(img): + if hasattr(img, "getbands"): + channels = len(img.getbands()) + else: + channels = img.channels + width, height = img.size + return [channels, height, width] + raise TypeError(f"Unexpected type {type(img)}") + + +@torch.jit.unused +def get_image_size(img: Any) -> List[int]: + if _is_pil_image(img): + return list(img.size) + raise TypeError(f"Unexpected type {type(img)}") + + +@torch.jit.unused +def get_image_num_channels(img: Any) -> int: + if 
_is_pil_image(img): + if hasattr(img, "getbands"): + return len(img.getbands()) + else: + return img.channels + raise TypeError(f"Unexpected type {type(img)}") + + +@torch.jit.unused +def hflip(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +@torch.jit.unused +def vflip(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +@torch.jit.unused +def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +@torch.jit.unused +def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +@torch.jit.unused +def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +@torch.jit.unused +def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + input_mode = img.mode + if input_mode in {"L", "1", "I", "F"}: + return img + + h, s, v = img.convert("HSV").split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over="ignore"): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, "L") + + img = Image.merge("HSV", (h, s, v)).convert(input_mode) + return img + + +@torch.jit.unused +def adjust_gamma( + img: Image.Image, + gamma: float, + gain: float = 1.0, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + input_mode = img.mode + img = img.convert("RGB") + gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part + + img = img.convert(input_mode) + return img + + +@torch.jit.unused +def pad( + img: Image.Image, + padding: Union[int, List[int], Tuple[int, ...]], + fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. 
Got {type(img)}") + + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg") + if not isinstance(padding_mode, str): + raise TypeError("Got inappropriate padding_mode arg") + + if isinstance(padding, list): + padding = tuple(padding) + + if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + if isinstance(padding, tuple) and len(padding) == 1: + # Compatibility with `functional_tensor.pad` + padding = padding[0] + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if padding_mode == "constant": + opts = _parse_fill(fill, img, name="fill") + if img.mode == "P": + palette = img.getpalette() + image = ImageOps.expand(img, border=padding, **opts) + image.putpalette(palette) + return image + + return ImageOps.expand(img, border=padding, **opts) + else: + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, tuple) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, tuple) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + p = [pad_left, pad_top, pad_right, pad_bottom] + cropping = -np.minimum(p, 0) + + if cropping.any(): + crop_left, crop_top, crop_right, crop_bottom = cropping + img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom)) + + pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0) + + if img.mode == "P": + palette = img.getpalette() + img = np.asarray(img) + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode) + img = Image.fromarray(img) + img.putpalette(palette) + return img + + img = np.asarray(img) + # RGB image + if len(img.shape) == 3: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) + # Grayscale image + if len(img.shape) == 2: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + + return Image.fromarray(img) + + +@torch.jit.unused +def crop( + img: Image.Image, + top: int, + left: int, + height: int, + width: int, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + return img.crop((left, top, left + width, top + height)) + + +@torch.jit.unused +def resize( + img: Image.Image, + size: Union[List[int], int], + interpolation: int = Image.BILINEAR, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. 
Got {type(img)}") + if not (isinstance(size, list) and len(size) == 2): + raise TypeError(f"Got inappropriate size arg: {size}") + + return img.resize(tuple(size[::-1]), interpolation) + + +@torch.jit.unused +def _parse_fill( + fill: Optional[Union[float, List[float], Tuple[float, ...]]], + img: Image.Image, + name: str = "fillcolor", +) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]: + + # Process fill color for affine transforms + num_channels = get_image_num_channels(img) + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_channels > 1: + fill = tuple([fill] * num_channels) + if isinstance(fill, (list, tuple)): + if len(fill) != num_channels: + msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})" + raise ValueError(msg.format(len(fill), num_channels)) + + fill = tuple(fill) + + if img.mode != "F": + if isinstance(fill, (list, tuple)): + fill = tuple(int(x) for x in fill) + else: + fill = int(fill) + + return {name: fill} + + +@torch.jit.unused +def affine( + img: Image.Image, + matrix: List[float], + interpolation: int = Image.NEAREST, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + output_size = img.size + opts = _parse_fill(fill, img) + return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts) + + +@torch.jit.unused +def rotate( + img: Image.Image, + angle: float, + interpolation: int = Image.NEAREST, + expand: bool = False, + center: Optional[Tuple[int, int]] = None, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + opts = _parse_fill(fill, img) + return img.rotate(angle, interpolation, expand, center, **opts) + + +@torch.jit.unused +def perspective( + img: Image.Image, + perspective_coeffs: List[float], + interpolation: int = Image.BICUBIC, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + opts = _parse_fill(fill, img) + + return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts) + + +@torch.jit.unused +def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + if num_output_channels == 1: + img = img.convert("L") + elif num_output_channels == 3: + img = img.convert("L") + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, "RGB") + else: + raise ValueError("num_output_channels should be either 1 or 3") + + return img + + +@torch.jit.unused +def invert(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.invert(img) + + +@torch.jit.unused +def posterize(img: Image.Image, bits: int) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.posterize(img, bits) + + +@torch.jit.unused +def solarize(img: Image.Image, threshold: int) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. 
Got {type(img)}") + return ImageOps.solarize(img, threshold) + + +@torch.jit.unused +def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Sharpness(img) + img = enhancer.enhance(sharpness_factor) + return img + + +@torch.jit.unused +def autocontrast(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.autocontrast(img) + + +@torch.jit.unused +def equalize(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.equalize(img) diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py new file mode 100644 index 00000000000..d0e7c17882b --- /dev/null +++ b/torchvision/transforms/_functional_tensor.py @@ -0,0 +1,962 @@ +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad + + +def _is_tensor_a_torch_image(x: Tensor) -> bool: + return x.ndim >= 2 + + +def _assert_image_tensor(img: Tensor) -> None: + if not _is_tensor_a_torch_image(img): + raise TypeError("Tensor is not a torch image.") + + +def get_dimensions(img: Tensor) -> List[int]: + _assert_image_tensor(img) + channels = 1 if img.ndim == 2 else img.shape[-3] + height, width = img.shape[-2:] + return [channels, height, width] + + +def get_image_size(img: Tensor) -> List[int]: + # Returns (w, h) of tensor image + _assert_image_tensor(img) + return [img.shape[-1], img.shape[-2]] + + +def get_image_num_channels(img: Tensor) -> int: + _assert_image_tensor(img) + if img.ndim == 2: + return 1 + elif img.ndim > 2: + return img.shape[-3] + + raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}") + + +def _max_value(dtype: torch.dtype) -> int: + if dtype == torch.uint8: + return 255 + elif dtype == torch.int8: + return 127 + elif dtype == torch.int16: + return 32767 + elif dtype == torch.int32: + return 2147483647 + elif dtype == torch.int64: + return 9223372036854775807 + else: + # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not + # easy. + return 1 + + +def _assert_channels(img: Tensor, permitted: List[int]) -> None: + c = get_dimensions(img)[0] + if c not in permitted: + raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}") + + +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + if image.dtype == dtype: + return image + + if image.is_floating_point(): + + # TODO: replace with dtype.is_floating_point when torchscript supports it + if torch.tensor(0, dtype=dtype).is_floating_point(): + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) + + # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # For data in the range 0-1, (float * 255).to(uint) is only 255 + # when float is exactly 1.0. + # `max + 1 - epsilon` provides more evenly distributed mapping of + # ranges of floats to ints. 
+ eps = 1e-3 + max_val = float(_max_value(dtype)) + result = image.mul(max_val + 1.0 - eps) + return result.to(dtype) + else: + input_max = float(_max_value(image.dtype)) + + # int to float + # TODO: replace with dtype.is_floating_point when torchscript supports it + if torch.tensor(0, dtype=dtype).is_floating_point(): + image = image.to(dtype) + return image / input_max + + output_max = float(_max_value(dtype)) + + # int to int + if input_max > output_max: + # factor should be forced to int for torch jit script + # otherwise factor is a float and image // factor can produce different results + factor = int((input_max + 1) // (output_max + 1)) + image = torch.div(image, factor, rounding_mode="floor") + return image.to(dtype) + else: + # factor should be forced to int for torch jit script + # otherwise factor is a float and image * factor can produce different results + factor = int((output_max + 1) // (input_max + 1)) + image = image.to(dtype) + return image * factor + + +def vflip(img: Tensor) -> Tensor: + _assert_image_tensor(img) + + return img.flip(-2) + + +def hflip(img: Tensor) -> Tensor: + _assert_image_tensor(img) + + return img.flip(-1) + + +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: + _assert_image_tensor(img) + + _, h, w = get_dimensions(img) + right = left + width + bottom = top + height + + if left < 0 or top < 0 or right > w or bottom > h: + padding_ltrb = [ + max(-left + min(0, right), 0), + max(-top + min(0, bottom), 0), + max(right - max(w, left), 0), + max(bottom - max(h, top), 0), + ] + return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) + return img[..., top:bottom, left:right] + + +def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + _assert_channels(img, [1, 3]) + + if num_output_channels not in (1, 3): + raise ValueError("num_output_channels should be either 1 or 3") + + if img.shape[-3] == 3: + r, g, b = img.unbind(dim=-3) + # This implementation closely follows the TF one: + # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) + l_img = l_img.unsqueeze(dim=-3) + else: + l_img = img.clone() + + if num_output_channels == 3: + return l_img.expand(img.shape) + + return l_img + + +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: + if brightness_factor < 0: + raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + + return _blend(img, torch.zeros_like(img), brightness_factor) + + +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [3, 1]) + c = get_dimensions(img)[0] + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + if c == 3: + mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) + else: + mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) + + return _blend(img, mean, contrast_factor) + + +def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + if not (isinstance(img, 
torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + if get_dimensions(img)[0] == 1: # Match PIL behaviour + return img + + orig_dtype = img.dtype + img = convert_image_dtype(img, torch.float32) + + img = _rgb2hsv(img) + h, s, v = img.unbind(dim=-3) + h = (h + hue_factor) % 1.0 + img = torch.stack((h, s, v), dim=-3) + img_hue_adj = _hsv2rgb(img) + + return convert_image_dtype(img_hue_adj, orig_dtype) + + +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + + if get_dimensions(img)[0] == 1: # Match PIL behaviour + return img + + return _blend(img, rgb_to_grayscale(img), saturation_factor) + + +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: + if not isinstance(img, torch.Tensor): + raise TypeError("Input img should be a Tensor.") + + _assert_channels(img, [1, 3]) + + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + result = img + dtype = img.dtype + if not torch.is_floating_point(img): + result = convert_image_dtype(result, torch.float32) + + result = (gain * result**gamma).clamp(0, 1) + + result = convert_image_dtype(result, dtype) + return result + + +def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: + ratio = float(ratio) + bound = _max_value(img1.dtype) + return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) + + +def _rgb2hsv(img: Tensor) -> Tensor: + r, g, b = img.unbind(dim=-3) + + # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ + # src/libImaging/Convert.c#L330 + maxc = torch.max(img, dim=-3).values + minc = torch.min(img, dim=-3).values + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occurring, so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + cr = maxc - minc + # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + ones = torch.ones_like(maxc) + s = cr / torch.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. 
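+ # Illustrative example of the guard described above (not from the patch): for a
+ # gray pixel r = g = b = 0.3, maxc = minc = 0.3, so eqc is True and cr = 0. With
+ # the denominator replaced by 1, rc = gc = bc = 0 and s = 0 / 1 = 0, so both h
+ # and s come out as 0 instead of NaN, matching the "erase S and H where
+ # maxc == minc" behaviour.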
+ cr_divisor = torch.where(eqc, ones, cr) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = hr + hg + hb + h = torch.fmod((h / 6.0 + 1.0), 1.0) + return torch.stack((h, s, maxc), dim=-3) + + +def _hsv2rgb(img: Tensor) -> Tensor: + h, s, v = img.unbind(dim=-3) + i = torch.floor(h * 6.0) + f = (h * 6.0) - i + i = i.to(dtype=torch.int32) + + p = torch.clamp((v * (1.0 - s)), 0.0, 1.0) + q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0) + t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) + + a1 = torch.stack((v, q, p, p, t, v), dim=-3) + a2 = torch.stack((t, v, v, q, p, p), dim=-3) + a3 = torch.stack((p, p, t, v, v, q), dim=-3) + a4 = torch.stack((a1, a2, a3), dim=-4) + + return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4) + + +def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: + # padding is left, right, top, bottom + + # crop if needed + if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: + neg_min_padding = [-min(x, 0) for x in padding] + crop_left, crop_right, crop_top, crop_bottom = neg_min_padding + img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] + padding = [max(x, 0) for x in padding] + + in_sizes = img.size() + + _x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...] + left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0] + right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3] + x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device) + + _y_indices = [i for i in range(in_sizes[-2])] + top_indices = [i for i in range(padding[2] - 1, -1, -1)] + bottom_indices = [-(i + 1) for i in range(padding[3])] + y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device) + + ndim = img.ndim + if ndim == 3: + return img[:, y_indices[:, None], x_indices[None, :]] + elif ndim == 4: + return img[:, :, y_indices[:, None], x_indices[None, :]] + else: + raise RuntimeError("Symmetric padding of N-D tensors are not supported yet") + + +def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: + if isinstance(padding, int): + if torch.jit.is_scripting(): + # This maybe unreachable + raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]") + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + return [pad_left, pad_right, pad_top, pad_bottom] + + +def pad( + img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant" +) -> Tensor: + _assert_image_tensor(img) + + if fill is None: + fill = 0 + + if not isinstance(padding, (int, tuple, list)): + raise TypeError("Got inappropriate padding arg") + if not isinstance(fill, (int, float)): + raise TypeError("Got inappropriate fill arg") + if not isinstance(padding_mode, str): + raise TypeError("Got inappropriate padding_mode arg") + + if isinstance(padding, tuple): + padding = 
list(padding) + + if isinstance(padding, list): + # TODO: Jit is failing on loading this op when scripted and saved + # https://github.com/pytorch/pytorch/issues/81100 + if len(padding) not in [1, 2, 4]: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + p = _parse_pad_padding(padding) + + if padding_mode == "edge": + # remap padding_mode str + padding_mode = "replicate" + elif padding_mode == "symmetric": + # route to another implementation + return _pad_symmetric(img, p) + + need_squeeze = False + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64): + # Here we temporarily cast input tensor to float + # until pytorch issue is resolved : + # https://github.com/pytorch/pytorch/issues/40763 + need_cast = True + img = img.to(torch.float32) + + if padding_mode in ("reflect", "replicate"): + img = torch_pad(img, p, mode=padding_mode) + else: + img = torch_pad(img, p, mode=padding_mode, value=float(fill)) + + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + img = img.to(out_dtype) + + return img + + +def resize( + img: Tensor, + size: List[int], + interpolation: str = "bilinear", + # TODO: in v0.17, change the default to True. This will a private function + # by then, so we don't care about warning here. + antialias: Optional[bool] = None, +) -> Tensor: + _assert_image_tensor(img) + + if isinstance(size, tuple): + size = list(size) + + if antialias is None: + antialias = False + + if antialias and interpolation not in ["bilinear", "bicubic"]: + # We manually set it to False to avoid an error downstream in interpolate() + # This behaviour is documented: the parameter is irrelevant for modes + # that are not bilinear or bicubic. We used to raise an error here, but + # now we don't as True is the default. 
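+ # For example, under the behaviour implemented here a call such as
+ # resize(img, [128, 128], interpolation="nearest", antialias=True) does not
+ # raise: antialias is silently reset to False below, since interpolate() only
+ # supports antialiasing for the "bilinear" and "bicubic" modes.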
+ antialias = False + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) + + # Define align_corners to avoid warnings + align_corners = False if interpolation in ["bilinear", "bicubic"] else None + + img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias) + + if interpolation == "bicubic" and out_dtype == torch.uint8: + img = img.clamp(min=0, max=255) + + img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) + + return img + + +def _assert_grid_transform_inputs( + img: Tensor, + matrix: Optional[List[float]], + interpolation: str, + fill: Optional[Union[int, float, List[float]]], + supported_interpolation_modes: List[str], + coeffs: Optional[List[float]] = None, +) -> None: + + if not (isinstance(img, torch.Tensor)): + raise TypeError("Input img should be Tensor") + + _assert_image_tensor(img) + + if matrix is not None and not isinstance(matrix, list): + raise TypeError("Argument matrix should be a list") + + if matrix is not None and len(matrix) != 6: + raise ValueError("Argument matrix should have 6 float values") + + if coeffs is not None and len(coeffs) != 8: + raise ValueError("Argument coeffs should have 8 float values") + + if fill is not None and not isinstance(fill, (int, float, tuple, list)): + warnings.warn("Argument fill should be either int, float, tuple or list") + + # Check fill + num_channels = get_dimensions(img)[0] + if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels: + msg = ( + "The number of elements in 'fill' cannot broadcast to match the number of " + "channels of the image ({} != {})" + ) + raise ValueError(msg.format(len(fill), num_channels)) + + if interpolation not in supported_interpolation_modes: + raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input") + + +def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]: + need_squeeze = False + # make image NCHW + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if out_dtype not in req_dtypes: + need_cast = True + req_dtype = req_dtypes[0] + img = img.to(req_dtype) + return img, need_cast, need_squeeze, out_dtype + + +def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor: + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + # it is better to round before cast + img = torch.round(img) + img = img.to(out_dtype) + + return img + + +def _apply_grid_transform( + img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]] +) -> Tensor: + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype]) + + if img.shape[0] > 1: + # Apply same grid to a batch of images + grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]) + + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice + if fill is not None: + mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device) + img = torch.cat((img, mask), dim=1) + + img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) + + # Fill with required color + if fill is not None: + mask = img[:, -1:, :, :] # N * 1 * H * W + img = img[:, :-1, 
:, :] # N * C * H * W + mask = mask.expand_as(img) + fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1) + fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) + if mode == "nearest": + mask = mask < 0.5 + img[mask] = fill_img[mask] + else: # 'bilinear' + img = img * mask + (1.0 - mask) * fill_img + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img + + +def _gen_affine_grid( + theta: Tensor, + w: int, + h: int, + ow: int, + oh: int, +) -> Tensor: + # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ + # AffineGridGenerator.cpp#L18 + # Difference with AffineGridGenerator is that: + # 1) we normalize grid values after applying theta + # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device) + x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device) + output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta) + return output_grid.view(1, oh, ow, 2) + + +def affine( + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) + + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + shape = img.shape + # grid will be generated on the same device as theta and img + grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: + + # Inspired of PIL implementation: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 + + # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + # Points are shifted due to affine matrix torch convention about + # the center point. 
Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5) + pts = torch.tensor( + [ + [-0.5 * w, -0.5 * h, 1.0], + [-0.5 * w, 0.5 * h, 1.0], + [0.5 * w, 0.5 * h, 1.0], + [0.5 * w, -0.5 * h, 1.0], + ] + ) + theta = torch.tensor(matrix, dtype=torch.float).view(2, 3) + new_pts = torch.matmul(pts, theta.T) + min_vals, _ = new_pts.min(dim=0) + max_vals, _ = new_pts.max(dim=0) + + # shift points to [0, w] and [0, h] interval to match PIL results + min_vals += torch.tensor((w * 0.5, h * 0.5)) + max_vals += torch.tensor((w * 0.5, h * 0.5)) + + # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 + tol = 1e-4 + cmax = torch.ceil((max_vals / tol).trunc_() * tol) + cmin = torch.floor((min_vals / tol).trunc_() * tol) + size = cmax - cmin + return int(size[0]), int(size[1]) # w, h + + +def rotate( + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + expand: bool = False, + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) + w, h = img.shape[-1], img.shape[-2] + ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h) + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + # grid will be generated on the same device as theta and img + grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) + + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor: + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device) + output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) + output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1 / output_grid2 - 1.0 + return output_grid.view(1, oh, ow, 2) + + +def perspective( + img: Tensor, + perspective_coeffs: List[float], + interpolation: str = "bilinear", + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + if not (isinstance(img, torch.Tensor)): + raise TypeError("Input img should be Tensor.") + + _assert_image_tensor(img) + + _assert_grid_transform_inputs( + img, + matrix=None, + interpolation=interpolation, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], + coeffs=perspective_coeffs, + ) + + ow, oh = img.shape[-1], img.shape[-2] + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + grid = _perspective_grid(perspective_coeffs, ow=ow, 
oh=oh, dtype=dtype, device=img.device) + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + + return kernel1d + + +def _get_gaussian_kernel2d( + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device +) -> Tensor: + kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) + kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) + kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) + return kernel2d + + +def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: + if not (isinstance(img, torch.Tensor)): + raise TypeError(f"img should be Tensor. Got {type(img)}") + + _assert_image_tensor(img) + + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) + + # padding = (left, right, top, bottom) + padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] + img = torch_pad(img, padding, mode="reflect") + img = conv2d(img, kernel, groups=img.shape[-3]) + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img + + +def invert(img: Tensor) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + _assert_channels(img, [1, 3]) + + return _max_value(img.dtype) - img + + +def posterize(img: Tensor, bits: int) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + if img.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}") + + _assert_channels(img, [1, 3]) + mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) + return img & mask + + +def solarize(img: Tensor, threshold: float) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + _assert_channels(img, [1, 3]) + + if threshold > _max_value(img.dtype): + raise TypeError("Threshold should be less than bound of img.") + + inverted_img = invert(img) + return torch.where(img >= threshold, inverted_img, img) + + +def _blurred_degenerate_image(img: Tensor) -> Tensor: + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + kernel = torch.ones((3, 3), dtype=dtype, device=img.device) + kernel[1, 1] = 5.0 + kernel /= kernel.sum() + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) + result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) + result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) + + result = img.clone() + result[..., 1:-1, 1:-1] = result_tmp + + return result + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: + if sharpness_factor < 0: + raise 
ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + + if img.size(-1) <= 2 or img.size(-2) <= 2: + return img + + return _blend(img, _blurred_degenerate_image(img), sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + _assert_channels(img, [1, 3]) + + bound = _max_value(img.dtype) + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) + scale = bound / (maximum - minimum) + eq_idxs = torch.isfinite(scale).logical_not() + minimum[eq_idxs] = 0 + scale[eq_idxs] = 1 + + return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) + + +def _scale_channel(img_chan: Tensor) -> Tensor: + # TODO: we should expect bincount to always be faster than histc, but this + # isn't always the case. Once + # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if + # block and only use bincount. + if img_chan.is_cuda: + hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) + else: + hist = torch.bincount(img_chan.reshape(-1), minlength=256) + + nonzero_hist = hist[hist != 0] + step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") + if step == 0: + return img_chan + + lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor") + lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) + + return lut[img_chan.to(torch.int64)].to(torch.uint8) + + +def _equalize_single_image(img: Tensor) -> Tensor: + return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))]) + + +def equalize(img: Tensor) -> Tensor: + + _assert_image_tensor(img) + + if not (3 <= img.ndim <= 4): + raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}") + if img.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}") + + _assert_channels(img, [1, 3]) + + if img.ndim == 3: + return _equalize_single_image(img) + + return torch.stack([_equalize_single_image(x) for x in img]) + + +def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: + _assert_image_tensor(tensor) + + if not tensor.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.") + + if tensor.ndim < 3: + raise ValueError( + f"Expected tensor to be a tensor image of size (..., C, H, W). 
Got tensor.size() = {tensor.size()}" + ) + + if not inplace: + tensor = tensor.clone() + + dtype = tensor.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=dtype, device=tensor.device) + if (std == 0).any(): + raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + return tensor.sub_(mean).div_(std) + + +def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: + _assert_image_tensor(img) + + if not inplace: + img = img.clone() + + img[..., i : i + h, j : j + w] = v + return img + + +def _create_identity_grid(size: List[int]) -> Tensor: + hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] + grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") + return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + + +def elastic_transform( + img: Tensor, + displacement: Tensor, + interpolation: str = "bilinear", + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + + if not (isinstance(img, torch.Tensor)): + raise TypeError(f"img should be Tensor. Got {type(img)}") + + size = list(img.shape[-2:]) + displacement = displacement.to(img.device) + + identity_grid = _create_identity_grid(size) + grid = identity_grid.to(img.device) + displacement + return _apply_grid_transform(img, grid, interpolation, fill) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c5b2a71d0d7..09c8bdbcfeb 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -15,7 +15,7 @@ accimage = None from ..utils import _log_api_usage_once -from . import functional_pil as F_pil, functional_tensor as F_t +from . 
import _functional_pil as F_pil, _functional_tensor as F_t class InterpolationMode(Enum): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 120998d0072..bfcbf1a5442 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,391 +1,11 @@ -import numbers -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +import warnings -import numpy as np -import torch -from PIL import Image, ImageEnhance, ImageOps +from torchvision.transforms._functional_pil import * # noqa -try: - import accimage -except ImportError: - accimage = None - - -@torch.jit.unused -def _is_pil_image(img: Any) -> bool: - if accimage is not None: - return isinstance(img, (Image.Image, accimage.Image)) - else: - return isinstance(img, Image.Image) - - -@torch.jit.unused -def get_dimensions(img: Any) -> List[int]: - if _is_pil_image(img): - if hasattr(img, "getbands"): - channels = len(img.getbands()) - else: - channels = img.channels - width, height = img.size - return [channels, height, width] - raise TypeError(f"Unexpected type {type(img)}") - - -@torch.jit.unused -def get_image_size(img: Any) -> List[int]: - if _is_pil_image(img): - return list(img.size) - raise TypeError(f"Unexpected type {type(img)}") - - -@torch.jit.unused -def get_image_num_channels(img: Any) -> int: - if _is_pil_image(img): - if hasattr(img, "getbands"): - return len(img.getbands()) - else: - return img.channels - raise TypeError(f"Unexpected type {type(img)}") - - -@torch.jit.unused -def hflip(img: Image.Image) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - return img.transpose(Image.FLIP_LEFT_RIGHT) - - -@torch.jit.unused -def vflip(img: Image.Image) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - return img.transpose(Image.FLIP_TOP_BOTTOM) - - -@torch.jit.unused -def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - enhancer = ImageEnhance.Brightness(img) - img = enhancer.enhance(brightness_factor) - return img - - -@torch.jit.unused -def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - enhancer = ImageEnhance.Contrast(img) - img = enhancer.enhance(contrast_factor) - return img - - -@torch.jit.unused -def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - enhancer = ImageEnhance.Color(img) - img = enhancer.enhance(saturation_factor) - return img - - -@torch.jit.unused -def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image: - if not (-0.5 <= hue_factor <= 0.5): - raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. 
Got {type(img)}") - - input_mode = img.mode - if input_mode in {"L", "1", "I", "F"}: - return img - - h, s, v = img.convert("HSV").split() - - np_h = np.array(h, dtype=np.uint8) - # uint8 addition take cares of rotation across boundaries - with np.errstate(over="ignore"): - np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, "L") - - img = Image.merge("HSV", (h, s, v)).convert(input_mode) - return img - - -@torch.jit.unused -def adjust_gamma( - img: Image.Image, - gamma: float, - gain: float = 1.0, -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - if gamma < 0: - raise ValueError("Gamma should be a non-negative real number") - - input_mode = img.mode - img = img.convert("RGB") - gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3 - img = img.point(gamma_map) # use PIL's point-function to accelerate this part - - img = img.convert(input_mode) - return img - - -@torch.jit.unused -def pad( - img: Image.Image, - padding: Union[int, List[int], Tuple[int, ...]], - fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0, - padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") - if not isinstance(padding_mode, str): - raise TypeError("Got inappropriate padding_mode arg") - - if isinstance(padding, list): - padding = tuple(padding) - - if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]: - raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") - - if isinstance(padding, tuple) and len(padding) == 1: - # Compatibility with `functional_tensor.pad` - padding = padding[0] - - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - - if padding_mode == "constant": - opts = _parse_fill(fill, img, name="fill") - if img.mode == "P": - palette = img.getpalette() - image = ImageOps.expand(img, border=padding, **opts) - image.putpalette(palette) - return image - - return ImageOps.expand(img, border=padding, **opts) - else: - if isinstance(padding, int): - pad_left = pad_right = pad_top = pad_bottom = padding - if isinstance(padding, tuple) and len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - if isinstance(padding, tuple) and len(padding) == 4: - pad_left = padding[0] - pad_top = padding[1] - pad_right = padding[2] - pad_bottom = padding[3] - - p = [pad_left, pad_top, pad_right, pad_bottom] - cropping = -np.minimum(p, 0) - - if cropping.any(): - crop_left, crop_top, crop_right, crop_bottom = cropping - img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom)) - - pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0) - - if img.mode == "P": - palette = img.getpalette() - img = np.asarray(img) - img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode) - img = Image.fromarray(img) - img.putpalette(palette) - return img - - img = np.asarray(img) - # RGB image - if len(img.shape) == 3: - img = np.pad(img, ((pad_top, pad_bottom), (pad_left, 
pad_right), (0, 0)), padding_mode) - # Grayscale image - if len(img.shape) == 2: - img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) - - return Image.fromarray(img) - - -@torch.jit.unused -def crop( - img: Image.Image, - top: int, - left: int, - height: int, - width: int, -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - return img.crop((left, top, left + width, top + height)) - - -@torch.jit.unused -def resize( - img: Image.Image, - size: Union[List[int], int], - interpolation: int = Image.BILINEAR, -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - if not (isinstance(size, list) and len(size) == 2): - raise TypeError(f"Got inappropriate size arg: {size}") - - return img.resize(tuple(size[::-1]), interpolation) - - -@torch.jit.unused -def _parse_fill( - fill: Optional[Union[float, List[float], Tuple[float, ...]]], - img: Image.Image, - name: str = "fillcolor", -) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]: - - # Process fill color for affine transforms - num_channels = get_image_num_channels(img) - if fill is None: - fill = 0 - if isinstance(fill, (int, float)) and num_channels > 1: - fill = tuple([fill] * num_channels) - if isinstance(fill, (list, tuple)): - if len(fill) != num_channels: - msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})" - raise ValueError(msg.format(len(fill), num_channels)) - - fill = tuple(fill) - - if img.mode != "F": - if isinstance(fill, (list, tuple)): - fill = tuple(int(x) for x in fill) - else: - fill = int(fill) - - return {name: fill} - - -@torch.jit.unused -def affine( - img: Image.Image, - matrix: List[float], - interpolation: int = Image.NEAREST, - fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - output_size = img.size - opts = _parse_fill(fill, img) - return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts) - - -@torch.jit.unused -def rotate( - img: Image.Image, - angle: float, - interpolation: int = Image.NEAREST, - expand: bool = False, - center: Optional[Tuple[int, int]] = None, - fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - opts = _parse_fill(fill, img) - return img.rotate(angle, interpolation, expand, center, **opts) - - -@torch.jit.unused -def perspective( - img: Image.Image, - perspective_coeffs: List[float], - interpolation: int = Image.BICUBIC, - fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, -) -> Image.Image: - - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - opts = _parse_fill(fill, img) - - return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts) - - -@torch.jit.unused -def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. 
Got {type(img)}") - - if num_output_channels == 1: - img = img.convert("L") - elif num_output_channels == 3: - img = img.convert("L") - np_img = np.array(img, dtype=np.uint8) - np_img = np.dstack([np_img, np_img, np_img]) - img = Image.fromarray(np_img, "RGB") - else: - raise ValueError("num_output_channels should be either 1 or 3") - - return img - - -@torch.jit.unused -def invert(img: Image.Image) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - return ImageOps.invert(img) - - -@torch.jit.unused -def posterize(img: Image.Image, bits: int) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - return ImageOps.posterize(img, bits) - - -@torch.jit.unused -def solarize(img: Image.Image, threshold: int) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - return ImageOps.solarize(img, threshold) - - -@torch.jit.unused -def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - - enhancer = ImageEnhance.Sharpness(img) - img = enhancer.enhance(sharpness_factor) - return img - - -@torch.jit.unused -def autocontrast(img: Image.Image) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - return ImageOps.autocontrast(img) - - -@torch.jit.unused -def equalize(img: Image.Image) -> Image.Image: - if not _is_pil_image(img): - raise TypeError(f"img should be PIL Image. Got {type(img)}") - return ImageOps.equalize(img) +warnings.warn( + "The torchvision.transforms.functional_pil module is deprecated " + "in 0.15 and will be **removed in 0.17**. Please don't rely on it. " + "You probably just need to use APIs in " + "torchvision.transforms.functional or in " + "torchvision.transforms.v2.functional." +) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d0e7c17882b..a9f7f509074 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,962 +1,11 @@ import warnings -from typing import List, Optional, Tuple, Union -import torch -from torch import Tensor -from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad +from torchvision.transforms._functional_tensor import * # noqa - -def _is_tensor_a_torch_image(x: Tensor) -> bool: - return x.ndim >= 2 - - -def _assert_image_tensor(img: Tensor) -> None: - if not _is_tensor_a_torch_image(img): - raise TypeError("Tensor is not a torch image.") - - -def get_dimensions(img: Tensor) -> List[int]: - _assert_image_tensor(img) - channels = 1 if img.ndim == 2 else img.shape[-3] - height, width = img.shape[-2:] - return [channels, height, width] - - -def get_image_size(img: Tensor) -> List[int]: - # Returns (w, h) of tensor image - _assert_image_tensor(img) - return [img.shape[-1], img.shape[-2]] - - -def get_image_num_channels(img: Tensor) -> int: - _assert_image_tensor(img) - if img.ndim == 2: - return 1 - elif img.ndim > 2: - return img.shape[-3] - - raise TypeError(f"Input ndim should be 2 or more. 
Got {img.ndim}") - - -def _max_value(dtype: torch.dtype) -> int: - if dtype == torch.uint8: - return 255 - elif dtype == torch.int8: - return 127 - elif dtype == torch.int16: - return 32767 - elif dtype == torch.int32: - return 2147483647 - elif dtype == torch.int64: - return 9223372036854775807 - else: - # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not - # easy. - return 1 - - -def _assert_channels(img: Tensor, permitted: List[int]) -> None: - c = get_dimensions(img)[0] - if c not in permitted: - raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}") - - -def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: - if image.dtype == dtype: - return image - - if image.is_floating_point(): - - # TODO: replace with dtype.is_floating_point when torchscript supports it - if torch.tensor(0, dtype=dtype).is_floating_point(): - return image.to(dtype) - - # float to int - if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( - image.dtype == torch.float64 and dtype == torch.int64 - ): - msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." - raise RuntimeError(msg) - - # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 - # For data in the range 0-1, (float * 255).to(uint) is only 255 - # when float is exactly 1.0. - # `max + 1 - epsilon` provides more evenly distributed mapping of - # ranges of floats to ints. - eps = 1e-3 - max_val = float(_max_value(dtype)) - result = image.mul(max_val + 1.0 - eps) - return result.to(dtype) - else: - input_max = float(_max_value(image.dtype)) - - # int to float - # TODO: replace with dtype.is_floating_point when torchscript supports it - if torch.tensor(0, dtype=dtype).is_floating_point(): - image = image.to(dtype) - return image / input_max - - output_max = float(_max_value(dtype)) - - # int to int - if input_max > output_max: - # factor should be forced to int for torch jit script - # otherwise factor is a float and image // factor can produce different results - factor = int((input_max + 1) // (output_max + 1)) - image = torch.div(image, factor, rounding_mode="floor") - return image.to(dtype) - else: - # factor should be forced to int for torch jit script - # otherwise factor is a float and image * factor can produce different results - factor = int((output_max + 1) // (input_max + 1)) - image = image.to(dtype) - return image * factor - - -def vflip(img: Tensor) -> Tensor: - _assert_image_tensor(img) - - return img.flip(-2) - - -def hflip(img: Tensor) -> Tensor: - _assert_image_tensor(img) - - return img.flip(-1) - - -def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: - _assert_image_tensor(img) - - _, h, w = get_dimensions(img) - right = left + width - bottom = top + height - - if left < 0 or top < 0 or right > w or bottom > h: - padding_ltrb = [ - max(-left + min(0, right), 0), - max(-top + min(0, bottom), 0), - max(right - max(w, left), 0), - max(bottom - max(h, top), 0), - ] - return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) - return img[..., top:bottom, left:right] - - -def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: - if img.ndim < 3: - raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - _assert_channels(img, [1, 3]) - - if num_output_channels not in (1, 3): - raise ValueError("num_output_channels should be either 1 
or 3") - - if img.shape[-3] == 3: - r, g, b = img.unbind(dim=-3) - # This implementation closely follows the TF one: - # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 - l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) - l_img = l_img.unsqueeze(dim=-3) - else: - l_img = img.clone() - - if num_output_channels == 3: - return l_img.expand(img.shape) - - return l_img - - -def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: - if brightness_factor < 0: - raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") - - _assert_image_tensor(img) - - _assert_channels(img, [1, 3]) - - return _blend(img, torch.zeros_like(img), brightness_factor) - - -def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: - if contrast_factor < 0: - raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") - - _assert_image_tensor(img) - - _assert_channels(img, [3, 1]) - c = get_dimensions(img)[0] - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - if c == 3: - mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) - else: - mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) - - return _blend(img, mean, contrast_factor) - - -def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: - if not (-0.5 <= hue_factor <= 0.5): - raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") - - if not (isinstance(img, torch.Tensor)): - raise TypeError("Input img should be Tensor image") - - _assert_image_tensor(img) - - _assert_channels(img, [1, 3]) - if get_dimensions(img)[0] == 1: # Match PIL behaviour - return img - - orig_dtype = img.dtype - img = convert_image_dtype(img, torch.float32) - - img = _rgb2hsv(img) - h, s, v = img.unbind(dim=-3) - h = (h + hue_factor) % 1.0 - img = torch.stack((h, s, v), dim=-3) - img_hue_adj = _hsv2rgb(img) - - return convert_image_dtype(img_hue_adj, orig_dtype) - - -def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: - if saturation_factor < 0: - raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") - - _assert_image_tensor(img) - - _assert_channels(img, [1, 3]) - - if get_dimensions(img)[0] == 1: # Match PIL behaviour - return img - - return _blend(img, rgb_to_grayscale(img), saturation_factor) - - -def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: - if not isinstance(img, torch.Tensor): - raise TypeError("Input img should be a Tensor.") - - _assert_channels(img, [1, 3]) - - if gamma < 0: - raise ValueError("Gamma should be a non-negative real number") - - result = img - dtype = img.dtype - if not torch.is_floating_point(img): - result = convert_image_dtype(result, torch.float32) - - result = (gain * result**gamma).clamp(0, 1) - - result = convert_image_dtype(result, dtype) - return result - - -def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: - ratio = float(ratio) - bound = _max_value(img1.dtype) - return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) - - -def _rgb2hsv(img: Tensor) -> Tensor: - r, g, b = img.unbind(dim=-3) - - # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ - # src/libImaging/Convert.c#L330 - maxc = torch.max(img, dim=-3).values - minc = torch.min(img, dim=-3).values - - # The algorithm erases S and H channel where `maxc = minc`. 
This avoids NaN - # from happening in the results, because - # + S channel has division by `maxc`, which is zero only if `maxc = minc` - # + H channel has division by `(maxc - minc)`. - # - # Instead of overwriting NaN afterwards, we just prevent it from occurring, so - # we don't need to deal with it in case we save the NaN in a buffer in - # backprop, if it is ever supported, but it doesn't hurt to do so. - eqc = maxc == minc - - cr = maxc - minc - # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. - ones = torch.ones_like(maxc) - s = cr / torch.where(eqc, ones, maxc) - # Note that `eqc => maxc = minc = r = g = b`. So the following calculation - # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it - # would not matter what values `rc`, `gc`, and `bc` have here, and thus - # replacing denominator with 1 when `eqc` is fine. - cr_divisor = torch.where(eqc, ones, cr) - rc = (maxc - r) / cr_divisor - gc = (maxc - g) / cr_divisor - bc = (maxc - b) / cr_divisor - - hr = (maxc == r) * (bc - gc) - hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) - hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) - h = hr + hg + hb - h = torch.fmod((h / 6.0 + 1.0), 1.0) - return torch.stack((h, s, maxc), dim=-3) - - -def _hsv2rgb(img: Tensor) -> Tensor: - h, s, v = img.unbind(dim=-3) - i = torch.floor(h * 6.0) - f = (h * 6.0) - i - i = i.to(dtype=torch.int32) - - p = torch.clamp((v * (1.0 - s)), 0.0, 1.0) - q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0) - t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) - i = i % 6 - - mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) - - a1 = torch.stack((v, q, p, p, t, v), dim=-3) - a2 = torch.stack((t, v, v, q, p, p), dim=-3) - a3 = torch.stack((p, p, t, v, v, q), dim=-3) - a4 = torch.stack((a1, a2, a3), dim=-4) - - return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4) - - -def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: - # padding is left, right, top, bottom - - # crop if needed - if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: - neg_min_padding = [-min(x, 0) for x in padding] - crop_left, crop_right, crop_top, crop_bottom = neg_min_padding - img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] - padding = [max(x, 0) for x in padding] - - in_sizes = img.size() - - _x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...] - left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0] - right_indices = [-(i + 1) for i in range(padding[1])] # e.g. 
[-1, -2, -3] - x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device) - - _y_indices = [i for i in range(in_sizes[-2])] - top_indices = [i for i in range(padding[2] - 1, -1, -1)] - bottom_indices = [-(i + 1) for i in range(padding[3])] - y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device) - - ndim = img.ndim - if ndim == 3: - return img[:, y_indices[:, None], x_indices[None, :]] - elif ndim == 4: - return img[:, :, y_indices[:, None], x_indices[None, :]] - else: - raise RuntimeError("Symmetric padding of N-D tensors are not supported yet") - - -def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: - if isinstance(padding, int): - if torch.jit.is_scripting(): - # This maybe unreachable - raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]") - pad_left = pad_right = pad_top = pad_bottom = padding - elif len(padding) == 1: - pad_left = pad_right = pad_top = pad_bottom = padding[0] - elif len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - else: - pad_left = padding[0] - pad_top = padding[1] - pad_right = padding[2] - pad_bottom = padding[3] - - return [pad_left, pad_right, pad_top, pad_bottom] - - -def pad( - img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant" -) -> Tensor: - _assert_image_tensor(img) - - if fill is None: - fill = 0 - - if not isinstance(padding, (int, tuple, list)): - raise TypeError("Got inappropriate padding arg") - if not isinstance(fill, (int, float)): - raise TypeError("Got inappropriate fill arg") - if not isinstance(padding_mode, str): - raise TypeError("Got inappropriate padding_mode arg") - - if isinstance(padding, tuple): - padding = list(padding) - - if isinstance(padding, list): - # TODO: Jit is failing on loading this op when scripted and saved - # https://github.com/pytorch/pytorch/issues/81100 - if len(padding) not in [1, 2, 4]: - raise ValueError( - f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" - ) - - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - - p = _parse_pad_padding(padding) - - if padding_mode == "edge": - # remap padding_mode str - padding_mode = "replicate" - elif padding_mode == "symmetric": - # route to another implementation - return _pad_symmetric(img, p) - - need_squeeze = False - if img.ndim < 4: - img = img.unsqueeze(dim=0) - need_squeeze = True - - out_dtype = img.dtype - need_cast = False - if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64): - # Here we temporarily cast input tensor to float - # until pytorch issue is resolved : - # https://github.com/pytorch/pytorch/issues/40763 - need_cast = True - img = img.to(torch.float32) - - if padding_mode in ("reflect", "replicate"): - img = torch_pad(img, p, mode=padding_mode) - else: - img = torch_pad(img, p, mode=padding_mode, value=float(fill)) - - if need_squeeze: - img = img.squeeze(dim=0) - - if need_cast: - img = img.to(out_dtype) - - return img - - -def resize( - img: Tensor, - size: List[int], - interpolation: str = "bilinear", - # TODO: in v0.17, change the default to True. This will a private function - # by then, so we don't care about warning here. 
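
For context on the `pad` kernel above: `_parse_pad_padding` normalises the 1-, 2- or 4-element `padding` argument into the `[left, right, top, bottom]` order expected by `torch.nn.functional.pad`. A small usage sketch against the renamed private module (illustrative values; the `F_t` alias mirrors the tests in this diff, and the private path can change without notice):

```python
import torch
import torchvision.transforms._functional_tensor as F_t  # private path introduced by this diff

img = torch.zeros(3, 4, 4, dtype=torch.uint8)
# A 2-element padding means left/right = 1 and top/bottom = 2,
# i.e. _parse_pad_padding([1, 2]) == [1, 1, 2, 2].
out = F_t.pad(img, [1, 2], fill=255, padding_mode="constant")
print(out.shape)  # torch.Size([3, 8, 6])
```
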
- antialias: Optional[bool] = None, -) -> Tensor: - _assert_image_tensor(img) - - if isinstance(size, tuple): - size = list(size) - - if antialias is None: - antialias = False - - if antialias and interpolation not in ["bilinear", "bicubic"]: - # We manually set it to False to avoid an error downstream in interpolate() - # This behaviour is documented: the parameter is irrelevant for modes - # that are not bilinear or bicubic. We used to raise an error here, but - # now we don't as True is the default. - antialias = False - - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) - - # Define align_corners to avoid warnings - align_corners = False if interpolation in ["bilinear", "bicubic"] else None - - img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias) - - if interpolation == "bicubic" and out_dtype == torch.uint8: - img = img.clamp(min=0, max=255) - - img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) - - return img - - -def _assert_grid_transform_inputs( - img: Tensor, - matrix: Optional[List[float]], - interpolation: str, - fill: Optional[Union[int, float, List[float]]], - supported_interpolation_modes: List[str], - coeffs: Optional[List[float]] = None, -) -> None: - - if not (isinstance(img, torch.Tensor)): - raise TypeError("Input img should be Tensor") - - _assert_image_tensor(img) - - if matrix is not None and not isinstance(matrix, list): - raise TypeError("Argument matrix should be a list") - - if matrix is not None and len(matrix) != 6: - raise ValueError("Argument matrix should have 6 float values") - - if coeffs is not None and len(coeffs) != 8: - raise ValueError("Argument coeffs should have 8 float values") - - if fill is not None and not isinstance(fill, (int, float, tuple, list)): - warnings.warn("Argument fill should be either int, float, tuple or list") - - # Check fill - num_channels = get_dimensions(img)[0] - if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels: - msg = ( - "The number of elements in 'fill' cannot broadcast to match the number of " - "channels of the image ({} != {})" - ) - raise ValueError(msg.format(len(fill), num_channels)) - - if interpolation not in supported_interpolation_modes: - raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input") - - -def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]: - need_squeeze = False - # make image NCHW - if img.ndim < 4: - img = img.unsqueeze(dim=0) - need_squeeze = True - - out_dtype = img.dtype - need_cast = False - if out_dtype not in req_dtypes: - need_cast = True - req_dtype = req_dtypes[0] - img = img.to(req_dtype) - return img, need_cast, need_squeeze, out_dtype - - -def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor: - if need_squeeze: - img = img.squeeze(dim=0) - - if need_cast: - if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - # it is better to round before cast - img = torch.round(img) - img = img.to(out_dtype) - - return img - - -def _apply_grid_transform( - img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]] -) -> Tensor: - - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype]) - - if img.shape[0] > 1: - # Apply same grid to a batch of images - grid = grid.expand(img.shape[0], 
grid.shape[1], grid.shape[2], grid.shape[3]) - - # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice - if fill is not None: - mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device) - img = torch.cat((img, mask), dim=1) - - img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) - - # Fill with required color - if fill is not None: - mask = img[:, -1:, :, :] # N * 1 * H * W - img = img[:, :-1, :, :] # N * C * H * W - mask = mask.expand_as(img) - fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1) - fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) - if mode == "nearest": - mask = mask < 0.5 - img[mask] = fill_img[mask] - else: # 'bilinear' - img = img * mask + (1.0 - mask) * fill_img - - img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) - return img - - -def _gen_affine_grid( - theta: Tensor, - w: int, - h: int, - ow: int, - oh: int, -) -> Tensor: - # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ - # AffineGridGenerator.cpp#L18 - # Difference with AffineGridGenerator is that: - # 1) we normalize grid values after applying theta - # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate - - d = 0.5 - base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device) - x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device) - base_grid[..., 0].copy_(x_grid) - y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1) - base_grid[..., 1].copy_(y_grid) - base_grid[..., 2].fill_(1) - - rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device) - output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta) - return output_grid.view(1, oh, ow, 2) - - -def affine( - img: Tensor, - matrix: List[float], - interpolation: str = "nearest", - fill: Optional[Union[int, float, List[float]]] = None, -) -> Tensor: - _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) - - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) - shape = img.shape - # grid will be generated on the same device as theta and img - grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) - return _apply_grid_transform(img, grid, interpolation, fill=fill) - - -def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: - - # Inspired of PIL implementation: - # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 - - # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. - # Points are shifted due to affine matrix torch convention about - # the center point. 
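
The `_compute_affine_output_size` helper that continues below derives the expanded canvas for `rotate(..., expand=True)` by pushing the four image corners through the affine matrix and taking their bounding box. A rough standalone sketch for a pure rotation (it skips the homogeneous coordinates, the shift back to the [0, w] x [0, h] range and the 1e-4 truncation that the real helper applies):

```python
import math
import torch

w, h, angle = 32, 24, math.radians(45.0)
matrix = [math.cos(angle), -math.sin(angle), 0.0, math.sin(angle), math.cos(angle), 0.0]
# Corner points centred on the image centre, as in the helper below.
pts = torch.tensor([[-0.5 * w, -0.5 * h], [-0.5 * w, 0.5 * h], [0.5 * w, 0.5 * h], [0.5 * w, -0.5 * h]])
theta = torch.tensor(matrix).view(2, 3)[:, :2]
new_pts = pts @ theta.T
size = new_pts.max(dim=0).values - new_pts.min(dim=0).values
print([int(torch.ceil(s)) for s in size])  # roughly [40, 40] for a 32x24 image rotated by 45 degrees
```
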
Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5) - pts = torch.tensor( - [ - [-0.5 * w, -0.5 * h, 1.0], - [-0.5 * w, 0.5 * h, 1.0], - [0.5 * w, 0.5 * h, 1.0], - [0.5 * w, -0.5 * h, 1.0], - ] - ) - theta = torch.tensor(matrix, dtype=torch.float).view(2, 3) - new_pts = torch.matmul(pts, theta.T) - min_vals, _ = new_pts.min(dim=0) - max_vals, _ = new_pts.max(dim=0) - - # shift points to [0, w] and [0, h] interval to match PIL results - min_vals += torch.tensor((w * 0.5, h * 0.5)) - max_vals += torch.tensor((w * 0.5, h * 0.5)) - - # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 - tol = 1e-4 - cmax = torch.ceil((max_vals / tol).trunc_() * tol) - cmin = torch.floor((min_vals / tol).trunc_() * tol) - size = cmax - cmin - return int(size[0]), int(size[1]) # w, h - - -def rotate( - img: Tensor, - matrix: List[float], - interpolation: str = "nearest", - expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, -) -> Tensor: - _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) - w, h = img.shape[-1], img.shape[-2] - ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h) - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) - # grid will be generated on the same device as theta and img - grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) - - return _apply_grid_transform(img, grid, interpolation, fill=fill) - - -def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor: - # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ - # src/libImaging/Geometry.c#L394 - - # - # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) - # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) - # - theta1 = torch.tensor( - [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device - ) - theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) - - d = 0.5 - base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) - x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) - base_grid[..., 0].copy_(x_grid) - y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) - base_grid[..., 1].copy_(y_grid) - base_grid[..., 2].fill_(1) - - rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device) - output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) - output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) - - output_grid = output_grid1 / output_grid2 - 1.0 - return output_grid.view(1, oh, ow, 2) - - -def perspective( - img: Tensor, - perspective_coeffs: List[float], - interpolation: str = "bilinear", - fill: Optional[Union[int, float, List[float]]] = None, -) -> Tensor: - if not (isinstance(img, torch.Tensor)): - raise TypeError("Input img should be Tensor.") - - _assert_image_tensor(img) - - _assert_grid_transform_inputs( - img, - matrix=None, - interpolation=interpolation, - fill=fill, - supported_interpolation_modes=["nearest", "bilinear"], - coeffs=perspective_coeffs, - ) - - ow, oh = img.shape[-1], img.shape[-2] - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - grid = _perspective_grid(perspective_coeffs, ow=ow, 
oh=oh, dtype=dtype, device=img.device) - return _apply_grid_transform(img, grid, interpolation, fill=fill) - - -def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - kernel1d = pdf / pdf.sum() - - return kernel1d - - -def _get_gaussian_kernel2d( - kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device -) -> Tensor: - kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) - kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) - kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) - return kernel2d - - -def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: - if not (isinstance(img, torch.Tensor)): - raise TypeError(f"img should be Tensor. Got {type(img)}") - - _assert_image_tensor(img) - - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) - kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) - - # padding = (left, right, top, bottom) - padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] - img = torch_pad(img, padding, mode="reflect") - img = conv2d(img, kernel, groups=img.shape[-3]) - - img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) - return img - - -def invert(img: Tensor) -> Tensor: - - _assert_image_tensor(img) - - if img.ndim < 3: - raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - - _assert_channels(img, [1, 3]) - - return _max_value(img.dtype) - img - - -def posterize(img: Tensor, bits: int) -> Tensor: - - _assert_image_tensor(img) - - if img.ndim < 3: - raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - if img.dtype != torch.uint8: - raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}") - - _assert_channels(img, [1, 3]) - mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) - return img & mask - - -def solarize(img: Tensor, threshold: float) -> Tensor: - - _assert_image_tensor(img) - - if img.ndim < 3: - raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - - _assert_channels(img, [1, 3]) - - if threshold > _max_value(img.dtype): - raise TypeError("Threshold should be less than bound of img.") - - inverted_img = invert(img) - return torch.where(img >= threshold, inverted_img, img) - - -def _blurred_degenerate_image(img: Tensor) -> Tensor: - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - - kernel = torch.ones((3, 3), dtype=dtype, device=img.device) - kernel[1, 1] = 5.0 - kernel /= kernel.sum() - kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - - result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) - result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) - result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) - - result = img.clone() - result[..., 1:-1, 1:-1] = result_tmp - - return result - - -def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: - if sharpness_factor < 0: - raise 
ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") - - _assert_image_tensor(img) - - _assert_channels(img, [1, 3]) - - if img.size(-1) <= 2 or img.size(-2) <= 2: - return img - - return _blend(img, _blurred_degenerate_image(img), sharpness_factor) - - -def autocontrast(img: Tensor) -> Tensor: - - _assert_image_tensor(img) - - if img.ndim < 3: - raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") - - _assert_channels(img, [1, 3]) - - bound = _max_value(img.dtype) - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - - minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) - maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) - scale = bound / (maximum - minimum) - eq_idxs = torch.isfinite(scale).logical_not() - minimum[eq_idxs] = 0 - scale[eq_idxs] = 1 - - return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) - - -def _scale_channel(img_chan: Tensor) -> Tensor: - # TODO: we should expect bincount to always be faster than histc, but this - # isn't always the case. Once - # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if - # block and only use bincount. - if img_chan.is_cuda: - hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) - else: - hist = torch.bincount(img_chan.reshape(-1), minlength=256) - - nonzero_hist = hist[hist != 0] - step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") - if step == 0: - return img_chan - - lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor") - lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) - - return lut[img_chan.to(torch.int64)].to(torch.uint8) - - -def _equalize_single_image(img: Tensor) -> Tensor: - return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))]) - - -def equalize(img: Tensor) -> Tensor: - - _assert_image_tensor(img) - - if not (3 <= img.ndim <= 4): - raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}") - if img.dtype != torch.uint8: - raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}") - - _assert_channels(img, [1, 3]) - - if img.ndim == 3: - return _equalize_single_image(img) - - return torch.stack([_equalize_single_image(x) for x in img]) - - -def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: - _assert_image_tensor(tensor) - - if not tensor.is_floating_point(): - raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.") - - if tensor.ndim < 3: - raise ValueError( - f"Expected tensor to be a tensor image of size (..., C, H, W). 
Got tensor.size() = {tensor.size()}"
-        )
-
-    if not inplace:
-        tensor = tensor.clone()
-
-    dtype = tensor.dtype
-    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
-    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
-    if (std == 0).any():
-        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
-    if mean.ndim == 1:
-        mean = mean.view(-1, 1, 1)
-    if std.ndim == 1:
-        std = std.view(-1, 1, 1)
-    return tensor.sub_(mean).div_(std)
-
-
-def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
-    _assert_image_tensor(img)
-
-    if not inplace:
-        img = img.clone()
-
-    img[..., i : i + h, j : j + w] = v
-    return img
-
-
-def _create_identity_grid(size: List[int]) -> Tensor:
-    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
-    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
-    return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2
-
-
-def elastic_transform(
-    img: Tensor,
-    displacement: Tensor,
-    interpolation: str = "bilinear",
-    fill: Optional[Union[int, float, List[float]]] = None,
-) -> Tensor:
-
-    if not (isinstance(img, torch.Tensor)):
-        raise TypeError(f"img should be Tensor. Got {type(img)}")
-
-    size = list(img.shape[-2:])
-    displacement = displacement.to(img.device)
-
-    identity_grid = _create_identity_grid(size)
-    grid = identity_grid.to(img.device) + displacement
-    return _apply_grid_transform(img, grid, interpolation, fill)
+warnings.warn(
+    "The torchvision.transforms.functional_tensor module is deprecated "
+    "in 0.15 and will be **removed in 0.17**. Please don't rely on it. "
+    "You probably just need to use APIs in "
+    "torchvision.transforms.functional or in "
+    "torchvision.transforms.v2.functional."
+)
diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py
index bdc3c89d7f3..fcd9c758ce4 100644
--- a/torchvision/transforms/v2/_auto_augment.py
+++ b/torchvision/transforms/v2/_auto_augment.py
@@ -6,7 +6,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
 from torchvision import datapoints, transforms as _transforms
-from torchvision.transforms import functional_tensor as _FT
+from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_spatial_size
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index 2ebb4f044cb..cf8d73dfff5 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -4,8 +4,8 @@
 import torch
 from torch.nn.functional import conv2d
 from torchvision import datapoints
-from torchvision.transforms import functional_pil as _FP
-from torchvision.transforms.functional_tensor import _max_value
+from torchvision.transforms import _functional_pil as _FP
+from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index c48250f3b96..baa2f24f21a 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -8,7 +8,8 @@
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
 from torchvision import datapoints
-from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms import _functional_pil as _FP
+from torchvision.transforms._functional_tensor import _pad_symmetric
 from torchvision.transforms.functional import (
     _check_antialias,
     _compute_resized_output_size as __compute_resized_output_size,
@@ -19,7 +20,6 @@
     pil_to_tensor,
     to_pil_image,
 )
-from torchvision.transforms.functional_tensor import _pad_symmetric
 from torchvision.utils import _log_api_usage_once
diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index c61f7a710d4..dc96838ddf1 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -4,8 +4,8 @@
 import torch
 from torchvision import datapoints
 from torchvision.datapoints import BoundingBoxFormat
-from torchvision.transforms import functional_pil as _FP
-from torchvision.transforms.functional_tensor import _max_value
+from torchvision.transforms import _functional_pil as _FP
+from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once
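
Taken together, the hunks above leave the old `functional_tensor` module in place as a warn-and-re-export shim (and `functional_pil` gets the same deprecation warning) while pointing every internal caller at the new underscore-prefixed paths. A sketch of what the migration looks like from the caller side (nothing here is a supported public API; the private modules can change without notice):

```python
import torch

# Old path: still importable in 0.15, but it now only re-exports the private module
# and emits the deprecation message shown above at import time.
import torchvision.transforms.functional_tensor as F_t_old  # noqa: F401

# New path used by the tests and the v2 kernels in this diff.
import torchvision.transforms._functional_tensor as F_t

print(F_t._max_value(torch.uint8))  # 255
```
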