Skip to content

Commit

Permalink
Put all dtype conversion stuff into the _misc namespaces (#7770)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Jul 28, 2023
1 parent 3591371 commit 72dcc17
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 144 deletions.
13 changes: 11 additions & 2 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,17 @@
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat
from ._misc import (
ConvertImageDtype,
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
SanitizeBoundingBox,
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

Expand Down
43 changes: 1 addition & 42 deletions torchvision/transforms/v2/_meta.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from typing import Any, Dict, Union

import torch

from torchvision import datapoints, transforms as _transforms
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform

from .utils import is_simple_tensor


class ConvertBoundingBoxFormat(Transform):
"""[BETA] Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY".
Expand All @@ -31,43 +27,6 @@ def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> da
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]


class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertImageDtype transform
.. warning::
Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
This function does not support PIL Image.
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""

_v1_transform_cls = _transforms.ConvertImageDtype

_transformed_types = (is_simple_tensor, datapoints.Image)

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=self.dtype, scale=True)


class ClampBoundingBox(Transform):
"""[BETA] Clamp bounding boxes to their corresponding image dimensions.
Expand Down
37 changes: 37 additions & 0 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,43 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=dtype, scale=self.scale)


class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertImageDtype transform
.. warning::
Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
This function does not support PIL Image.
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""

_v1_transform_cls = _transforms.ConvertImageDtype

_transformed_types = (is_simple_tensor, datapoints.Image)

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=self.dtype, scale=True)


class SanitizeBoundingBox(Transform):
"""[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.
Expand Down
8 changes: 4 additions & 4 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
convert_image_dtype,
to_dtype,
to_dtype_image_tensor,
to_dtype_video,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
Expand Down Expand Up @@ -158,13 +154,17 @@
vflip,
)
from ._misc import (
convert_image_dtype,
gaussian_blur,
gaussian_blur_image_pil,
gaussian_blur_image_tensor,
gaussian_blur_video,
normalize,
normalize_image_tensor,
normalize_video,
to_dtype,
to_dtype_image_tensor,
to_dtype_video,
)
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from torchvision.utils import _log_api_usage_once

from ._meta import _num_value_bits, to_dtype_image_tensor
from ._misc import _num_value_bits, to_dtype_image_tensor
from ._utils import is_simple_tensor


Expand Down
95 changes: 0 additions & 95 deletions torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value

from torchvision.utils import _log_api_usage_once

Expand Down Expand Up @@ -279,97 +278,3 @@ def clamp_bounding_box(
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)


def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:

if image.dtype == dtype:
return image
elif not scale:
return image.to(dtype)

float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point

if float_input:
# float to float
if float_output:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)

if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)


# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
return to_dtype_image_tensor(image, dtype=dtype, scale=True)


def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image_tensor(video, dtype, scale=scale)


def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(to_dtype)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Image):
output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to(dtype)
else:
raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
95 changes: 95 additions & 0 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.nn.functional import conv2d, pad as torch_pad

from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

from torchvision.utils import _log_api_usage_once
Expand Down Expand Up @@ -182,3 +183,97 @@ def gaussian_blur(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)


def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:

if image.dtype == dtype:
return image
elif not scale:
return image.to(dtype)

float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point

if float_input:
# float to float
if float_output:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)

if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)


# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
return to_dtype_image_tensor(image, dtype=dtype, scale=True)


def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image_tensor(video, dtype, scale=scale)


def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(to_dtype)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Image):
output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to(dtype)
else:
raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")

0 comments on commit 72dcc17

Please sign in to comment.