diff --git a/pyproject.toml b/pyproject.toml index 6c3f7a7..29012ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ sections= ["FUTURE", "STDLIB", "THIRDPARTY", "PYTORCH", "FIRSTPARTY", "LOCALFOLD skip = [ "torchvision/datasets/__init__.py", + "torchvision/transforms/__init__.py", ] [tool.black] @@ -36,9 +37,3 @@ skip = [ line-length = 120 target-version = ["py36"] -exclude = ''' -/( - \.git - | __pycache__ -)/ -''' \ No newline at end of file diff --git a/references/segmentation/main.py b/references/segmentation/main.py new file mode 100644 index 0000000..d1fac1f --- /dev/null +++ b/references/segmentation/main.py @@ -0,0 +1,15 @@ +from transforms import get_transform + +import torch + +from torchvision.features import Image, Segmentation + +image = Image(torch.rand(3, 480, 640)) +seg = Segmentation(torch.randint(0, 256, size=image.shape, dtype=torch.uint8)) +sample = image, seg + +transform = get_transform(train=True) +train_image, train_seg = transform(sample) + +transform = get_transform(train=False) +eval_image, eval_seg = transform(sample) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py new file mode 100644 index 0000000..dc034ae --- /dev/null +++ b/references/segmentation/transforms.py @@ -0,0 +1,30 @@ +from typing import Sequence + +from torchvision import transforms as T + + +def get_transform( + *, + train: bool, + base_size: int = 520, + crop_size: int = 480, + horizontal_flip_probability: float = 0.5, + mean: Sequence[float] = (0.485, 0.456, 0.406), + std: Sequence[float] = (0.229, 0.224, 0.225), +): + + if train: + min_size = base_size // 2 + max_size = base_size * 2 + transforms = [T.RandomResize(min_size, max_size)] + + if horizontal_flip_probability > 0: + transforms.append(T.RandomApply(T.HorizontalFlip(), p=horizontal_flip_probability)) + + transforms.append(T.RandomCrop(crop_size)) + + augmentation = T.Compose(*transforms) + else: + augmentation = T.Resize(base_size) + + return T.Compose(augmentation, T.Normalize(mean, std)) diff --git a/torchvision/datasets/utils/__init__.py b/torchvision/datasets/utils/__init__.py index 885fac0..cfad362 100644 --- a/torchvision/datasets/utils/__init__.py +++ b/torchvision/datasets/utils/__init__.py @@ -1,2 +1,3 @@ from ._bunch import * +from ._query import * from ._resource import * diff --git a/torchvision/datasets/utils/_query.py b/torchvision/datasets/utils/_query.py new file mode 100644 index 0000000..6362ae4 --- /dev/null +++ b/torchvision/datasets/utils/_query.py @@ -0,0 +1,62 @@ +import collections.abc +from typing import Any, Callable, Iterator, Optional, Set, Tuple, TypeVar, Union + +import torch + +from torchvision.features import BoundingBox, Image + +T = TypeVar("T") + +__all__ = ["Query"] + + +class Query: + def __init__(self, sample: Any) -> None: + self.sample = sample + + @staticmethod + def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]: + if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)): + for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample: + yield from Query._query_recursively(item, fn) + else: + result = fn(sample) + if result is not None: + yield result + + def query(self, fn: Callable[[Any], Optional[T]], *, unique: bool = True) -> Union[T, Set[T]]: + results = set(self._query_recursively(self.sample, fn)) + if not results: + raise RuntimeError("Query turned up empty.") + + if not unique: + return results + + if len(results) > 1: + raise RuntimeError(f"Found 
more than one result: {sorted(results)}") + + return results.pop() + + def image_size(self) -> Optional[Tuple[int, int]]: + def fn(sample: Any) -> Optional[Tuple[int, int]]: + if not isinstance(sample, torch.Tensor): + return None + elif type(sample) is torch.Tensor: + return sample.shape[-2:] + elif isinstance(sample, (Image, BoundingBox)): + return sample.image_size + else: + return None + + return self.query(fn) + + def batch_size(self) -> Optional[int]: + def fn(sample: Any) -> Optional[int]: + if not isinstance(sample, torch.Tensor): + return None + elif isinstance(sample, Image): + return sample.batch_size + else: + return None + + return self.query(fn) diff --git a/torchvision/features/__init__.py b/torchvision/features/__init__.py new file mode 100644 index 0000000..8c6110d --- /dev/null +++ b/torchvision/features/__init__.py @@ -0,0 +1,3 @@ +from ._bounding_box import * +from ._core import * +from ._image import * diff --git a/torchvision/features/_bounding_box.py b/torchvision/features/_bounding_box.py new file mode 100644 index 0000000..59d8a6c --- /dev/null +++ b/torchvision/features/_bounding_box.py @@ -0,0 +1,163 @@ +import enum +from typing import Any, Optional, Tuple, Union + +import torch + +from ._core import TensorFeature + +__all__ = ["BoundingBox", "BoundingBoxFormat"] + + +class BoundingBoxFormat(enum.Enum): + XYXY = "XYXY" + XYWH = "XYWH" + CXCYWH = "CXCYWH" + + +class BoundingBox(TensorFeature): + formats = BoundingBoxFormat + + @staticmethod + def _parse_format(format: Union[str, BoundingBoxFormat]) -> BoundingBoxFormat: + if isinstance(format, str): + format = format.upper() + return BoundingBox.formats(format) + + def __init__( + self, + data: Any = None, + *, + image_size: Tuple[int, int], + format: Union[str, BoundingBoxFormat], + ): + super().__init__() + self._image_size = image_size + self._format = self._parse_format(format) + + self._convert_to_xyxy = { + self.formats.XYWH: self._xywh_to_xyxy, + self.formats.CXCYWH: self._cxcywh_to_xyxy, + } + self._convert_from_xyxy = { + self.formats.XYWH: self._xyxy_to_xywh, + self.formats.CXCYWH: self._xyxy_to_cxcywh, + } + + def __new__( + cls, + data: Any = None, + *, + image_size: Tuple[int, int], + format: Union[str, BoundingBoxFormat], + ): + # Since torch.Tensor defines both __new__ and __init__, we also need to do that since we change the signature + return super().__new__(cls, data) + + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + *, + like: Optional["BoundingBox"] = None, + image_size: Optional[Tuple[int, int]] = None, + format: Optional[Union[str, BoundingBoxFormat]] = None, + ) -> "BoundingBox": + params = cls._parse_from_tensor_args(like=like, image_size=image_size, format=format) + + format = params.get("format") or "xyxy" + + image_size = params.get("image_size") + if image_size is None: + # TODO: compute minimum image size needed to hold this bounding box depending on format + image_size = (0, 0) + + return cls(tensor, image_size=image_size, format=format) + + @property + def image_size(self) -> Tuple[int, int]: + return self._image_size + + @property + def format(self) -> BoundingBoxFormat: + return self._format + + @classmethod + def from_parts( + cls, + a, + b, + c, + d, + *, + format: Union[str, BoundingBoxFormat], + like: Optional["BoundingBox"] = None, + image_size: Optional[Tuple[int, int]] = None, + ) -> "BoundingBox": + parts = torch.broadcast_tensors( + *[part if isinstance(part, torch.Tensor) else torch.as_tensor(part) for part in (a, b, c, d)] + ) + return 
cls.from_tensor(torch.stack(parts, dim=-1), like=like, image_size=image_size, format=format) + + def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return self.unbind(-1) + + def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox": + format = self._parse_format(format) + # FIXME: cloning does not preserve custom attributes such as image_size or format + # bounding_box = self.clone() + bounding_box = self + + if format == self.format: + return bounding_box + + if self.format != self.formats.XYXY: + bounding_box = self._convert_to_xyxy[self.format](bounding_box) + + if format != self.formats.XYXY: + bounding_box = self._convert_from_xyxy[format](bounding_box) + + return bounding_box + + @staticmethod + def _xywh_to_xyxy(input: "BoundingBox") -> "BoundingBox": + x, y, w, h = input.to_parts() + + x1 = x + y1 = y + x2 = x + w + y2 = y + h + + return BoundingBox.from_parts(x1, y1, x2, y2, like=input, format="xyxy") + + @staticmethod + def _xyxy_to_xywh(input: "BoundingBox") -> "BoundingBox": + x1, y1, x2, y2 = input.to_parts() + + x = x1 + y = y1 + w = x2 - x1 + h = y2 - y1 + + return BoundingBox.from_parts(x, y, w, h, format="xywh") + + @staticmethod + def _cxcywh_to_xyxy(input: "BoundingBox") -> "BoundingBox": + cx, cy, w, h = input.to_parts() + + x1 = cx - 0.5 * w + y1 = cy - 0.5 * h + x2 = cx + 0.5 * w + y2 = cy + 0.5 * h + + return BoundingBox.from_parts(x1, y1, x2, y2, like=input, format="xyxy") + + @staticmethod + def _xyxy_to_cxcywh(input: "BoundingBox") -> "BoundingBox": + x1, y1, x2, y2 = input.to_parts() + + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + + return BoundingBox.from_parts(cx, cy, w, h, like=input, format="cxcywh") diff --git a/torchvision/features/_core.py b/torchvision/features/_core.py new file mode 100644 index 0000000..bf9864b --- /dev/null +++ b/torchvision/features/_core.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, Optional, Type, TypeVar + +import torch + +__all__ = ["Feature", "TensorFeature"] + +TF = TypeVar("TF", bound="TensorFeature") + + +# A Feature might not necessarily be a Tensor. Think text. 
+class Feature: + pass + + +class TensorFeature(torch.Tensor, Feature): + def __new__(cls, data: Any = None): + if data is None: + data = torch.tensor([]) + requires_grad = False + return torch.Tensor._make_subclass(cls, data, requires_grad) + + @staticmethod + def _parse_from_tensor_args(*, like: Optional["TensorFeature"], **attrs: Any) -> Dict[str, Any]: + if not attrs: + raise ValueError() + + params = {name: getattr(like, name) for name in attrs.keys()} if like is not None else {} + params.update({name: value for name, value in attrs.items() if value is not None}) + return params + + @classmethod + def from_tensor(cls: Type[TF], tensor: torch.Tensor, *, like: Optional[TF] = None) -> TF: + return cls(tensor) diff --git a/torchvision/features/_image.py b/torchvision/features/_image.py new file mode 100644 index 0000000..001429a --- /dev/null +++ b/torchvision/features/_image.py @@ -0,0 +1,33 @@ +from typing import Tuple + +from ._core import TensorFeature + +__all__ = ["Image", "Segmentation"] + + +class Image(TensorFeature): + @property + def image_size(self) -> Tuple[int, int]: + return self.shape[-2:] + + @property + def batch_size(self) -> int: + return self.shape[0] if self.ndim == 4 else 0 + + def batch(self) -> "Image": + if self.batch_size > 0: + return self + + return Image.from_tensor(self.unsqueeze(0), like=self) + + def unbatch(self) -> "Image": + if self.batch_size == 0: + return self + elif self.batch_size == 1: + return Image.from_tensor(self.squeeze(0), like=self) + else: + raise RuntimeError("Cannot unbatch an image tensor if batch contains more than one image.") + + +class Segmentation(Image): + pass diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py new file mode 100644 index 0000000..fcb6da1 --- /dev/null +++ b/torchvision/transforms/__init__.py @@ -0,0 +1,6 @@ +from ._transform import * + +from ._augmentation import * +from ._container import * +from ._geometry import * +from ._misc import * diff --git a/torchvision/transforms/_augmentation.py b/torchvision/transforms/_augmentation.py new file mode 100644 index 0000000..7446c8e --- /dev/null +++ b/torchvision/transforms/_augmentation.py @@ -0,0 +1,42 @@ +from typing import Any, Dict + +import torch + +from torchvision.datasets.utils import Query +from torchvision.features import Image +from torchvision.transforms.utils import ImageRequirement + +from ._transform import Transform + +__all__ = ["MixUp"] + + +class MixUp(Transform): + def __init__(self, alpha: float = 0.5) -> None: + super().__init__() + self.alpha = alpha + self._dist = torch.distributions.Beta(alpha, alpha) + + def get_params(self, sample: Any) -> Dict[str, Any]: + perm = torch.randperm(Query(sample).batch_size() or 1) + lam = self._dist.sample() + return dict(perm=perm, lam=lam) + + @staticmethod + def _is_ordered(perm) -> bool: + return bool(perm.eq(perm.sort().values).all()) + + @staticmethod + @ImageRequirement.batched(noop_if_single=True) + def image(input: Image, *, perm: torch.Tensor, lam: torch.Tensor) -> Image: + if MixUp._is_ordered(perm): + return input + + shuffled = input[perm] + mixed = lam * input + (1 - lam) * shuffled + return Image.from_tensor(mixed, like=input) + + @staticmethod + def label(input, *, perm: torch.Tensor, lam: torch.Tensor): + # TODO + pass diff --git a/torchvision/transforms/_container.py b/torchvision/transforms/_container.py new file mode 100644 index 0000000..dcc517b --- /dev/null +++ b/torchvision/transforms/_container.py @@ -0,0 +1,93 @@ +from typing import Any, List + +import 
torch
+
+from ._transform import _TransformBase
+
+__all__ = ["Compose", "RandomApply", "RandomChoice", "RandomOrder"]
+
+
+class _ContainerTransform(_TransformBase):
+    def supports(self, obj: Any) -> bool:
+        raise NotImplementedError()
+
+    def forward(self, *inputs: Any, strict: bool = False) -> Any:
+        raise NotImplementedError()
+
+    def _make_repr(self, lines: List[str]) -> str:
+        extra_repr = self.extra_repr()
+        if extra_repr:
+            lines = [extra_repr, *lines]
+        head = f"{type(self).__name__}("
+        tail = ")"
+        body = [f" {line.rstrip()}" for line in lines]
+        return "\n".join([head, *body, tail])
+
+
+class _WrapperTransform(_ContainerTransform):
+    def __init__(self, transform: _TransformBase):
+        super().__init__()
+        self._transform = transform
+
+    def supports(self, obj: Any) -> bool:
+        return self._transform.supports(obj)
+
+    def __repr__(self) -> str:
+        return self._make_repr(repr(self._transform).splitlines())
+
+
+class _MultiTransform(_ContainerTransform):
+    def __init__(self, *transforms: _TransformBase) -> None:
+        super().__init__()
+        self._transforms = transforms
+
+    def supports(self, obj: Any, *, strict: bool = False) -> bool:
+        aggregator = all if strict else any
+        return aggregator(transform.supports(obj) for transform in self._transforms)
+
+    def __repr__(self) -> str:
+        lines = []
+        for idx, transform in enumerate(self._transforms):
+            partial_lines = repr(transform).splitlines()
+            lines.append(f"({idx:d}): {partial_lines[0]}")
+            lines.extend(partial_lines[1:])
+        return self._make_repr(lines)
+
+
+class Compose(_MultiTransform):
+    def forward(self, *inputs: Any, strict: bool = False) -> Any:
+        for transform in self._transforms:
+            inputs = transform(*inputs, strict=strict)
+        return inputs
+
+
+class RandomApply(_WrapperTransform):
+    def __init__(self, transform: _TransformBase, *, p: float = 0.5) -> None:
+        super().__init__(transform)
+        self._p = p
+
+    def forward(self, *inputs: Any, strict: bool = False) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        # Apply the wrapped transform with probability `p`, otherwise return the sample unchanged.
+        if float(torch.rand(())) >= self._p:
+            # TODO: Should we check here whether the sample is supported if strict=True?
+ return sample + + return self._transform(sample, strict=strict) + + def extra_repr(self) -> str: + return f"p={self._p}" + + +class RandomChoice(_MultiTransform): + def forward(self, *inputs: Any, strict: bool = False) -> Any: + idx = torch.randint(len(self._transforms), size=()).item() + transform = self._transforms[idx] + return transform(*inputs, strict=strict) + + +class RandomOrder(_MultiTransform): + def forward(self, *inputs: Any, strict: bool = False) -> Any: + for idx in torch.randperm(len(self._transforms)): + transform = self._transforms[idx] + inputs = transform(*inputs, strict=strict) + return inputs diff --git a/torchvision/transforms/_geometry.py b/torchvision/transforms/_geometry.py new file mode 100644 index 0000000..00906da --- /dev/null +++ b/torchvision/transforms/_geometry.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, Optional, Tuple, Type, Union + +import torch +from torch.nn.functional import interpolate + +from torchvision.datasets.utils import Query +from torchvision.features import BoundingBox, Feature, Image, Segmentation +from torchvision.transforms import Transform +from torchvision.transforms.utils import ImageRequirement + +__all__ = [ + "HorizontalFlip", + "Resize", + "RandomResize", + "Crop", + "RandomCrop", + "CenterCrop", +] + + +class HorizontalFlip(Transform): + @staticmethod + def image(input: Image) -> Image: + return Image(input.flip((-1,))) + + @staticmethod + def segmentation(input: Segmentation) -> Segmentation: + return Segmentation(HorizontalFlip.image(input)) + + @staticmethod + def bounding_box(input: BoundingBox) -> BoundingBox: + x, y, w, h = input.convert("xywh").to_parts() + x = input.image_size[1] - (x + w) + return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh") + + +class Resize(Transform): + def __init__(self, size: Union[int, Tuple[int, int]], *, interpolation_mode: str = "bilinear") -> None: + super().__init__() + self.size = (size, size) if isinstance(size, int) else size + self.interpolation_mode = interpolation_mode + + def get_params(self, sample: Any) -> Union[Dict[str, Any], Dict[Type[Feature], Dict[str, Any]]]: + return { + Image: dict(size=self.size, interpolation_mode=self.interpolation_mode), + Segmentation: dict(size=self.size, interpolation_mode="nearest"), + } + + @staticmethod + @ImageRequirement.batched() + def image(input: Image, *, size: Tuple[int, int], interpolation_mode: str = "bilinear") -> Image: + return interpolate(input, size, mode=interpolation_mode) + + @staticmethod + def segmentation( + input: Segmentation, *, size: Tuple[int, int], interpolation_mode: str = "nearest" + ) -> Segmentation: + return Segmentation(Resize.image(input, size=size, interpolation_mode=interpolation_mode)) + + def extra_repr(self) -> str: + extra_repr = f"size={self.size}" + if self.interpolation_mode != "bilinear": + extra_repr += f", interpolation_mode={self.interpolation_mode}" + return extra_repr + + +class RandomResize(Transform, wraps=Resize): + def __init__(self, min_size: Union[int, Tuple[int, int]], max_size: Optional[Union[int, Tuple[int, int]]]) -> None: + super().__init__() + self.min_size = (min_size, min_size) if isinstance(min_size, int) else min_size + self.max_size = (max_size, max_size) if isinstance(max_size, int) else max_size + + def get_params(self, sample: Any) -> Dict[str, Any]: + min_height, min_width = self.min_size + max_height, max_width = self.max_size + height = int(torch.randint(min_height, max_height + 1, size=())) + width = int(torch.randint(min_width, max_width + 1, size=())) 
+ return dict(size=(height, width)) + + def extra_repr(self) -> str: + return f"min_size={self.min_size}, max_size={self.max_size}" + + +class Crop(Transform): + def __init__(self, crop_box: BoundingBox) -> None: + super().__init__() + self.crop_box = crop_box.convert("xyxy") + + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict(crop_box=self.crop_box) + + @staticmethod + def image(input: Image, *, crop_box: BoundingBox) -> Image: + # FIXME: pad input in case it is smaller than crop_box + x1, y1, x2, y2 = crop_box.convert("xyxy").to_parts() + return input[..., y1 : y2 + 1, x1 : x2 + 1] + + @staticmethod + def segmentation(input: Segmentation, *, crop_box) -> Segmentation: + return Segmentation(Crop.image(input, crop_box=crop_box)) + + +class CenterCrop(Transform, wraps=Crop): + def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None: + super().__init__() + self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size + + def get_params(self, sample: Any) -> Dict[str, Any]: + image_size = Query(sample).image_size() + image_height, image_width = image_size + cx = image_width // 2 + cy = image_height // 2 + h, w = self.crop_size + crop_box = BoundingBox.from_parts(cx, cy, w, h, image_size=image_size, format="cxcywh") + return dict(crop_box=crop_box.convert("xyxy")) + + def extra_repr(self) -> str: + return f"crop_size={self.crop_size}" + + +class RandomCrop(Transform, wraps=Crop): + def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None: + super().__init__() + self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size + + def get_params(self, sample: Any) -> Dict[str, Any]: + image_size = Query(sample).image_size() + image_height, image_width = image_size + crop_height, crop_width = self.crop_size + x = torch.randint(0, image_width - crop_width + 1, size=()) if crop_width < image_width else 0 + y = torch.randint(0, image_height - crop_height + 1, size=()) if crop_height < image_height else 0 + crop_box = BoundingBox.from_parts(x, y, crop_width, crop_height, image_size=image_size, format="xywh") + return dict(crop_box=crop_box.convert("xyxy")) + + def extra_repr(self) -> str: + return f"crop_size={self.crop_size}" diff --git a/torchvision/transforms/_misc.py b/torchvision/transforms/_misc.py new file mode 100644 index 0000000..8e18ec6 --- /dev/null +++ b/torchvision/transforms/_misc.py @@ -0,0 +1,31 @@ +from typing import Any, Dict, Sequence + +import torch + +from torchvision.features import Image +from torchvision.transforms import Transform + +__all__ = ["Normalize"] + + +class Normalize(Transform): + def __init__(self, mean: Sequence[float], std: Sequence[float]): + super().__init__() + self.mean = mean + self.std = std + + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict(mean=self.mean, std=self.std) + + @staticmethod + def _channel_stats_to_tensor(stats: Sequence[float], *, like: torch.Tensor) -> torch.Tensor: + return torch.as_tensor(stats, device=like.device, dtype=like.dtype).view(-1, 1, 1) + + @staticmethod + def image(input: Image, *, mean: Sequence[float], std: Sequence[float]) -> Image: + mean_t = Normalize._channel_stats_to_tensor(mean, like=input) + std_t = Normalize._channel_stats_to_tensor(std, like=input) + return (input - mean_t) / std_t + + def extra_repr(self) -> str: + return f"mean={tuple(self.mean)}, std={tuple(self.std)}" diff --git a/torchvision/transforms/_transform.py b/torchvision/transforms/_transform.py new file mode 100644 index 0000000..f644c82 --- 
/dev/null
+++ b/torchvision/transforms/_transform.py
@@ -0,0 +1,420 @@
+import collections.abc
+import difflib
+import inspect
+import re
+import warnings
+from typing import Any, Callable, Dict, Optional, Type, Union, cast
+
+import torch
+from torch import nn
+
+from torchvision import features
+
+__all__ = ["Transform", "Identity", "Lambda"]
+
+
+class _TransformBase(nn.Module):
+    _BUILTIN_FEATURE_TYPES = (
+        features.BoundingBox,
+        features.Image,
+        features.Segmentation,
+    )
+
+
+# TODO: Maybe we should name this 'SampleTransform'?
+class Transform(_TransformBase):
+    """Base class for transforms.
+
+    A transform operates on a full sample at once, which might be a nested container of elements to transform. The
+    non-container elements of the sample will be dispatched to feature transforms based on their type, if that type is
+    supported by the transform. Each transform needs to define at least one feature transform, which is canonically
+    done as a static method:
+
+    .. code-block::
+
+        class ImageIdentity(Transform):
+            @staticmethod
+            def image(input):
+                return input
+
+    To achieve correct results for a complete sample, each transform should implement feature transforms for every
+    :class:`Feature` it can handle:
+
+    .. code-block::
+
+        class Identity(Transform):
+            @staticmethod
+            def image(input):
+                return input
+
+            @staticmethod
+            def bounding_box(input):
+                return input
+
+            ...
+
+    If the name of a static method, converted to camel case, matches the name of a :class:`Feature`, the feature
+    transform is auto-registered. Supported pairs are:
+
+    +----------------+----------------+
+    | method name    | `Feature`      |
+    +================+================+
+    | `image`        | `Image`        |
+    +----------------+----------------+
+    | `bounding_box` | `BoundingBox`  |
+    +----------------+----------------+
+    | `segmentation` | `Segmentation` |
+    +----------------+----------------+
+
+    If you don't want to stick to this scheme, you can disable the auto-registration and perform it manually:
+
+    .. code-block::
+
+        def my_image_transform(input):
+            ...
+
+        class MyTransform(Transform, auto_register=False):
+            def __init__(self):
+                super().__init__()
+                self.register_feature_transform(Image, my_image_transform)
+                self.register_feature_transform(BoundingBox, self.my_bounding_box_transform)
+
+            @staticmethod
+            def my_bounding_box_transform(input):
+                ...
+
+    In any case, the registration will assert that the feature transform can be invoked with
+    ``feature_transform(input, **params)``.
+
+    .. warning::
+
+        Feature transforms are **registered on the class and not on the instance**. This means you cannot have two
+        instances of the same :class:`Transform` with different feature transforms.
+
+    If the feature transforms need additional parameters, you need to
+    overwrite the :meth:`~Transform.get_params` method. It needs to return the parameter dictionary that will be
+    unpacked and its contents passed to each feature transform:
+
+    .. code-block::

+        class Rotate(Transform):
+            def __init__(self, degrees):
+                super().__init__()
+                self.degrees = degrees
+
+            def get_params(self, sample):
+                return dict(degrees=self.degrees)
+
+            def image(input, *, degrees):
+                ...
+
+    The :meth:`~Transform.get_params` method will be invoked once per sample. Thus, in case of randomly sampled
+    parameters they will be the same for all features of the whole sample.
+
+    .. code-block::
+
+        class RandomRotate(Transform):
+            def __init__(self, range):
+                super().__init__()
+                self._dist = torch.distributions.Uniform(-range, range)
+
+            def get_params(self, sample):
+                return dict(degrees=self._dist.sample().item())
+
+            @staticmethod
+            def image(input, *, degrees):
+                ...
+
+    In case the sampling depends on one or more features at runtime, the complete ``sample`` gets passed to the
+    :meth:`Transform.get_params` method. Derivative transforms that only change the parameter sampling while the
+    feature transforms stay identical can simply wrap the transform they dispatch to:
+
+    .. code-block::
+
+        class RandomRotate(Transform, wraps=Rotate):
+            def get_params(self, sample):
+                return dict(degrees=float(torch.rand(())) * 30.0)
+
+    To transform a sample, you simply call an instance of the transform with it:
+
+    .. code-block::
+
+        transform = MyTransform()
+        sample = dict(input=Image(torch.tensor(...)), target=BoundingBox(torch.tensor(...)), ...)
+        transformed_sample = transform(sample)
+
+    By default, elements in the ``sample`` that are not supported by the transform are returned without modification.
+    You can set the ``strict=True`` flag to force a transformation of every element or bail out in case one is not
+    supported.
+
+    .. note::
+
+        To use a :class:`Transform` with a dataset, simply use it as a map:
+
+        .. code-block::
+
+            torchvision.datasets.load(...).map(MyTransform())
+    """
+
+    _FEATURE_NAME_MAP = {
+        "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", feature_type.__name__)]): feature_type
+        for feature_type in _TransformBase._BUILTIN_FEATURE_TYPES
+    }
+
+    def __init_subclass__(
+        cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False
+    ):
+        cls._feature_transforms: Dict[Type[features.Feature], Callable] = (
+            {} if wraps is None else wraps._feature_transforms.copy()
+        )
+        if auto_register:
+            cls._auto_register(verbose=verbose)
+
+    @staticmethod
+    def _has_allowed_signature(feature_transform: Callable) -> bool:
+        """Checks if ``feature_transform`` can be invoked with ``feature_transform(input, **params)``"""
+
+        parameters = tuple(inspect.signature(feature_transform).parameters.values())
+        if not parameters:
+            return False
+        elif len(parameters) == 1:
+            return parameters[0].kind != inspect.Parameter.KEYWORD_ONLY
+        else:
+            return parameters[1].kind != inspect.Parameter.POSITIONAL_ONLY
+
+    @classmethod
+    def register_feature_transform(cls, feature_type: Type[features.Feature], transform: Callable) -> None:
+        """Registers a transform for a given feature type on the class.
+
+        If a transform object is called or :meth:`Transform.apply` is invoked, inputs are dispatched to the registered
+        transforms based on their type.
+
+        Args:
+            feature_type: Feature type the transformation is registered for.
+            transform: Feature transformation.
+
+        Raises:
+            TypeError: If ``transform`` cannot be invoked with ``transform(input, **params)``.
+        """
+        if not cls._has_allowed_signature(transform):
+            raise TypeError("Transform cannot be invoked with transform(input, **params)")
+        cls._feature_transforms[feature_type] = transform
+
+    @classmethod
+    def _auto_register(cls, *, verbose: bool = False) -> None:
+        """Auto-registers methods on the class as feature transforms if they meet the following criteria:
+
+        1. They are static.
+        2. They can be invoked with `cls.feature_transform(input, **params)`.
+        3. They are public.
+        4. Their name, converted to camel case, matches the name of a builtin feature, e.g. 'bounding_box' and 'BoundingBox'.
+ + The name from 4. determines for which feature the method is registered. + + .. note:: + + The ``auto_register`` and ``verbose`` flags need to be passed as keyword arguments to the class: + + .. code-block:: + + class MyTransform(Transform, auto_register=True, verbose=True): + ... + + Args: + verbose: If ``True``, prints to STDOUT which methods were registered or why a method was not registered + """ + for name, value in inspect.getmembers(cls): + # check if attribute is a static method and was defined in the subclass + # TODO: this needs to be revisited to allow subclassing of custom transforms + if not (name in cls.__dict__ and inspect.isfunction(value)): + continue + + not_registered_prefix = f"{cls.__name__}.{name}() was not registered as feature transform, because" + + if not cls._has_allowed_signature(value): + if verbose: + print(f"{not_registered_prefix} it cannot be invoked with {name}(input, **params).") + continue + + if name.startswith("_"): + if verbose: + print(f"{not_registered_prefix} it is private.") + continue + + try: + feature_type = cls._FEATURE_NAME_MAP[name] + except KeyError: + if verbose: + msg = f"{not_registered_prefix} its name doesn't match any known feature type." + suggestions = difflib.get_close_matches(name, cls._FEATURE_NAME_MAP.keys(), n=1) + if suggestions: + msg = ( + f"{msg} Did you mean to name it '{suggestions[0]}' " + f"to be registered for type '{cls._FEATURE_NAME_MAP[suggestions[0]].__name__}'?" + ) + print(msg) + continue + + cls.register_feature_transform(feature_type, value) + if verbose: + print( + f"{cls.__name__}.{name}() was registered as feature transform for type '{feature_type.__name__}'." + ) + + @classmethod + def from_callable( + cls, + feature_transform: Union[Callable, Dict[Type[features.Feature], Callable]], + *, + name: str = "FromCallable", + get_params: Optional[Union[Dict[str, Any], Callable[[Any], Dict[str, Any]]]] = None, + ) -> "Transform": + """Creates a new transform from a callable. + + Args: + feature_transform: Feature transform that will be registered to handle :class:`Image`'s. Can be passed as + dictionary in which case each key-value-pair is needs to consists of a ``Feature`` type and the + corresponding transform. + name: Name of the transform. + get_params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. + Can be passed as callable in which case it will be called with the transform instance (``self``) and + the input of the transform. + + Raises: + TypeError: If ``feature_transform`` cannot be invoked with ``feature_transform(input, **params)``. + """ + if get_params is None: + get_params = dict() + attributes = dict( + get_params=get_params if callable(get_params) else lambda self, sample: get_params, + ) + transform_cls = cast(Type[Transform], type(name, (cls,), attributes)) + + if callable(feature_transform): + feature_transform = {features.Image: feature_transform} + for feature_type, transform in feature_transform.items(): + transform_cls.register_feature_transform(feature_type, transform) + + return transform_cls() + + @classmethod + def supports(cls, obj: Any) -> bool: + """Checks if object or type is supported. + + Args: + obj: Object or type. + """ + # TODO: should this handle containers? 
+ feature_type = obj if isinstance(obj, type) else type(obj) + return feature_type is torch.Tensor or feature_type in cls._feature_transforms + + @classmethod + def apply(cls, input: torch.Tensor, **params: Any) -> torch.Tensor: + """Applies the registered feature transform to the input based on its type. + + This can be uses as type generic functional interface: + + .. code-block:: + + transform = Rotate.apply + transformed_image = transform(Image(torch.tensor(...)), degrees=30.0) + transformed_bbox = transform(BoundingBox(torch.tensor(...)), degrees=-10.0) + + Args: + input: ``input`` in ``feature_transform(input, **params)`` + **params: Parameter dictionary ``params`` in ``feature_transform(input, **params)``. + + Returns: + Transformed input. + """ + feature_type = type(input) + if not cls.supports(feature_type): + raise TypeError(f"{cls.__name__}() is not able to handle inputs of type {feature_type}.") + + # TODO: if the other domain libraries adopt our approach, we need to make the default type variable. + if feature_type is torch.Tensor: + feature_type = features.Image + input = feature_type.from_tensor(input) + + feature_transform = cls._feature_transforms[feature_type] + return feature_transform(input, **params) + + def _apply_recursively( + self, sample: Any, *, params: Union[Dict[str, Any], Dict[Type[features.Feature], Dict[str, Any]]], strict: bool + ) -> Any: + """Recurses through a sample and invokes :meth:`Transform.apply` on non-container elements. + + If an element is not supported by the transform, it is returned untransformed. + + Args: + sample: Sample. + params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. + strict: If ``True``, raises an error in case a non-container element of the ``sample`` is not supported by + the transform. + + Raises: + TypeError: If ``strict=True`` and a non-container element of the ``sample`` is not supported. + """ + if isinstance(sample, collections.abc.Sequence): + return [self._apply_recursively(item, params=params, strict=strict) for item in sample] + elif isinstance(sample, collections.abc.Mapping): + return {name: self._apply_recursively(item, params=params, strict=strict) for name, item in sample.items()} + else: + feature_type = type(sample) + if not self.supports(feature_type): + if not strict: + return sample + + raise TypeError(f"{type(self).__name__}() is not able to handle inputs of type {feature_type}.") + + if params and all(isinstance(key, type) and issubclass(key, features.Feature) for key in params.keys()): + params = params[feature_type] + return self.apply(sample, **params) + + def get_params(self, sample: Any) -> Union[Dict[str, Any], Dict[Type[features.Feature], Dict[str, Any]]]: + """Returns the parameter dictionary used to transform the current sample. + + .. note:: + + Since ``sample`` might be a nested container, it is recommended to use the + :class:`torchvision.datasets.utils.Query` class if you need to extract information from it. + + Args: + sample: Current sample. + + Returns: + Parameter dictionary ``params`` in ``feature_transform(input, **params)``. 
+ """ + return dict() + + def forward( + self, + *inputs: Any, + params: Optional[Union[Dict[str, Any], Dict[Type[features.Feature], Dict[str, Any]]]] = None, + strict: bool = True, + ) -> Any: + if not self._feature_transforms: + raise RuntimeError(f"{type(self).__name__}() has no registered feature transform.") + + sample = inputs if len(inputs) > 1 else inputs[0] + if params is None: + params = self.get_params(sample) + return self._apply_recursively(sample, params=params, strict=strict) + + +class Identity(Transform): + """Identity transform that supports all built-in :class:`Features`.""" + + def __init__(self): + super().__init__() + for feature_type in self._BUILTIN_FEATURE_TYPES: + self.register_feature_transform(feature_type, lambda input, **params: input) + + +class Lambda(Transform): + def __new__(cls, lambd: Callable) -> Transform: + warnings.warn("transforms.Lambda(...) is deprecated. Use transforms.Transform.from_callable(...) instead.") + # We need to generate a new class everytime a Lambda transform is created, since the feature transforms are + # registered on the class rather than on the instance. If we didn't, registering a feature transform will + # overwrite it on **all** Lambda transform instances. + return Transform.from_callable(lambd, name="Lambda") diff --git a/torchvision/transforms/utils.py b/torchvision/transforms/utils.py new file mode 100644 index 0000000..87cdd47 --- /dev/null +++ b/torchvision/transforms/utils.py @@ -0,0 +1,32 @@ +import functools +from typing import Any + +from torchvision.features import Image + +__all__ = ["Requirement", "ImageRequirement"] + + +class Requirement: + pass + + +class ImageRequirement(Requirement): + @classmethod + def batched(cls, noop_if_single: bool = False): + def outer_wrapper(feature_transform): + @functools.wraps(feature_transform) + def inner_wrapper(input: Image, **params: Any) -> Image: + if noop_if_single and input.batch_size <= 1: + return input + elif input.batch_size >= 1: + return feature_transform(input, **params) + + output = feature_transform(input.batch(), **params) + if output.batch_size <= 1: + output = output.unbatch() + + return output + + return inner_wrapper + + return outer_wrapper diff --git a/transforms.md b/transforms.md new file mode 100644 index 0000000..089d490 --- /dev/null +++ b/transforms.md @@ -0,0 +1,189 @@ +# RFC: Transformations + +One major concern of the upstream implementation of datasets in `torchvision.dataset` are the optional `transform`, `target_transform`, and `transforms`. Apart from not being used in a [standardized way across datasets](https://gist.github.com/pmeier/14756fe0501287b2974e03ab8d651c10), from a theoretical standpoint the transformation has nothing to do with the dataset. + +We already decided that the rework will remove transforms from datasets. This means we need to provide an alternative that offers a similarly easy interface for simple use cases, while still being flexible enough to handle complex use cases. Opposed to the upstream implementation, the new API will return a datapipe and each sample drawn from it will be a dictionary that holds the data. This means we cannot simply use `dataset.map(torchvision.transforms.HorizontalFlip())` without modification. + +This post outlines how I envision the dataset-transformation-interplay with the reworked dataset API. You can find a PoC implementation in the same PR. + +## `Feature`'s of a dataset sample + +PyTorch is flexible enough to allow powerful subclassing of the core class: the `torch.Tensor`. 
Thus instead of having only the raw `Tensor`, we could add some custom `Feature` classes that represent the individual elements returned by a dataset. These classes could have special fields + +```python +class Image(Feature): + @property + def image_size(self) -> Tuple[int, int]: + return self.shape[-2:] +``` + +or methods + +```python +class BoundingBox(Feature): + @property + def image_size(self) -> Tuple[int, int]: + ... + + @property + def format(self) -> str: + ... + + def convert(self, format: str) -> "BoundingBox": + ... +``` + +that would make interacting with them a lot easier. Since for all intents and purposes they still act like regular tensors, a user does not have to worry about this at all. Plus, by returning these specific types instead of raw `Tensor`'s for a dataset, a transformation can know what to do with a passed argument. + +## `Transform`'ing a dataset sample + +Passing `transform=HorizontalFlip()` to the constructor of a dataset is hard to beat in terms of UX. Since we already decided, that this will not be a feature after the rework, the next best thing to apply a transform as a map to each sample, i.e. `dataset = dataset.map(HorizontalFlip())`. Unfortunately, this is not possible with our current transforms, since they cannot deal with a dictionary as input. In particular, all current transformations assume the input is an image. + +The new API should have the following features: + +1. Each transform should know which features it is able to handle and should do so without needing to explicitly calling anything. For example, `HorizontalFlip` needs to handle `Image`'s as well as `BoundingBox`'es. +2. The interface should be kept BC, i.e. `HorizontalFlip()(torch.rand(3, 16, 16))` should still work. +3. The transform should be able to handle the dataset output directly, i.e. a (possibly nested) dictionary of features. This means, by default inputs that are not supported should be ignored by the transform and returned without modification. +4. Apart from passing a multi-feature sample as a dictionary, it should also be possible to multiple arguments to the transform for convenience, e.g. `image, bounding_box = HorizontalFlip()(image, bounding_box)`. + +Points 2. - 4. only concern the dispatch, which can be handled in a custom `Transform` and thus be hidden from users as well someone who writes a new transform. Ignoring that, a transform could look like this: + +```python +class HorizontalFlip(Transform): + @staticmethod + def image(input: Image) -> Image: + return Image(input.flip((-1,))) + + @staticmethod + def bounding_box(input: BoundingBox) -> BoundingBox: + x, y, w, h = input.convert("xywh").to_parts() + x = input.image_size[1] - (x + w) + return BoundingBox.from_parts(x, y, w, h, image_size=input.image_size, format="xywh") +``` + +This is as simple as it gets (no I didn't leave out anything here; this is actually the full implementation!): we have a separate feature transforms to handle the different elements of the sample. As is, `HorizontalFlip()` will only transform `Image`'s and `BoundingBox`'es. If we later want to add support for `KeyPoints`, we can simply add a `key_points(input)` feature transform and the dispatch mechanism will handle everything else in the background. + +The next complexity step are transformations that need additional parameters. Since the feature transforms are static methods, we do not have access to instance attributes. To solve this, every transform can overwrite the `get_params` method. 
It should return a parameter dictionary that will be unpacked and passed to each transform: + +```python +class Rotate(Transform): + def __init__(self, degrees: float) -> None: + super().__init__() + self.degrees = degrees + + def get_params(self, sample: Any = None) -> Dict[str, Any]: + return dict(degrees=self.degrees) + + @staticmethod + def image(input: Image, *, degrees: torch.Tensor) -> Image: + return input + + @staticmethod + def bounding_box(input: BoundingBox, *, degrees: torch.Tensor) -> BoundingBox: + return input +``` + +This is a little more verbose than how our "normal" transformations work, but is very similar to interface of our random transforms, which will be discussed later. + +The proposed approach will be sufficient to handle all transformations if the following assumptions hold: + +1. The transform is applied to the elements of a single sample. +2. The parameters needed to perform the transform are either static or can be obtained from the sample at runtime. +3. After the parameter sampling, each feature can be transformed independent of the others. + +## Functional interface + +Since the actual feature transforms are static, they can completely replace the functional API that we currently have. For example, instead of using `transforms.functional.horizontal_flip(...)` we now can use `transforms.HorizontalFlip.image(...)`. Even better, through `Transform.apply` we have access to the same dispatch mechanism the stateful variant uses: + +```python +transform = Rotate.apply +transformed_image = transform(image, degrees=30.0) +transformed_bbox = transform(bbox, degrees=-10.0) +``` + +One design principle that follows from the clean separation of parameter sampling and feature transforms, is that the latter should not contain be deterministic (ignoring non-deterministic behavior of PyTorch operators), i.e. do not contain any random elements. A conclusion from this is that random tranforms should not expose public feature transforms, but rather wrap the ones from the deterministic variant. For example, internally `RandomRotate` should call `Rotate.image()`, but should not expose `RandomRotate.image()` itself to avoid confusion. + +## Random transforms + +There are in general two types of random transforms: + +1. Transforms that sample their parameters at random for each sample, but are always applied. +2. Transforms that are applied at random given a probability, but have fixed parameters. + +These cases are not mutually exclusive, but since they address independent concepts they also can be handled independently. + +### Random parameters + +Since the dispatch for the complete sample happens from a single point in the `forward` method, it is easy to perform all feature transforms with the same random parameters. 
For example, for `RandomRotate` the implementation might look like + +```python +class RandomRotate(Transform): + def __init__(self, low: float, high: float) -> None: + super().__init__() + self._dist = torch.distributions.Uniform(low, high) + + def get_params(self, sample: Any = None) -> Dict[str, Any]: + return dict(degrees=self._dist.sample().item()) + + # The feature transforms are just defined for illustration purposes + # and should actually be hidden as explained above + @staticmethod + def image(input: Image, *, degrees: torch.Tensor) -> Image: + return Rotate.image(input, degrees=degrees) + + @staticmethod + def bounding_box(input: BoundingBox, *, degrees: torch.Tensor) -> BoundingBox: + return Rotate.image(input, degrees=degrees) +``` + +Since the `get_params()` methods gets passes the complete sample, the sampled parameters can also depend on the input, which is required for some transforms. For example, [`RandomErasing`](https://github.com/pytorch/vision/blob/f3aff2fa60c110c25df671b6f99ffb26727cb8ae/torchvision/transforms/transforms.py#L1608) needs to know image size to determine location and size of the erased patch. + +### Random application + +Since the randomness in these transforms is independent of the actual feature transforms, it can be completely separated from the transform. One clean approach is to provide building blocks and let the user combine them as needed. For example, instead of providing `RandomHorizontalFlip`, we would provide `RandomApply` and `HorizontalFlip` such that `RandomHorizontalFlip() == RandomApply(HorizontalFlip())`. This has three advantages: + +1. `RandomApply` can be reused for other transforms than `HorizontalFlip` +2. The implementation of a new transform is less error-prone, since a transform like `HorizontalFlip` does not need to deal with the randomness. +3. Some transforms, for example color space transforms, have use cases as deterministic and random variant. With this approach we have access to both without defining two distinct transforms. + + +## Upgrade guide + +Upgrading a simple image-only transformation + +```python +class Foo(nn.Module): + def __init__(self, bar, baz="baz"): + super().__init__() + self.bar = bar + self.baz = baz + + def forward(self, input): + return foo(input, self.bar, baz=self.baz) +``` + +is straight forward. + +```python +class Foo(Transform): + def __init__(self, bar, baz="baz"): + super().__init__() + self.bar = bar + self.baz = baz + + def get_params(self): + return dict(bar=self.bar, baz=self.baz) + + @staticmethod + def image(input, *, bar, baz): + return foo(input, bar, baz=baz) +``` + +1. The new style is a little more verbose. This is due to the fact that `Transform`'s are now able to handle multiple input types and a single source of parameters is needed for them. For random transforms the verbosity stays the same, since they already have this single source for the random parameters. +2. If `foo` only takes one positional argument, defining a custom `image` method could also be replaced by `self.register_feature_transform(Image, foo)` in `__init__()`. + + +## Other implications + +- With this proposal, the new `Transform`'s will only work well with samples from the datasets, if we wrap everything in the proposed custom `Feature` classes. Since we own the datasets that won't be an issue for the builtin variants. This only becomes an issue if someone wants to use custom datasets while retaining full capability of the proposed transforms. 
In such a case I think it is reasonable to require them to adhere to the structure and also use the correct `Feature` types in their custom dataset. +- Since the `Feature`'s are custom `Tensor`'s, the `Transform`'s will no longer work with `PIL` images. AFAIK, the plan to drop `PIL` support is not new at all. If `torchvision.io` is able to handle all kinds of image I/O, we could also completely remove `PIL` as dependency. If `PIL` is available nevertheless, we can provide `pil_to_tensor` and `tensor_to_pil` functions under `torch.utils` for convenience, but don't rely on them in the normal workflow. \ No newline at end of file diff --git a/transforms_poc.py b/transforms_poc.py new file mode 100644 index 0000000..457218f --- /dev/null +++ b/transforms_poc.py @@ -0,0 +1,33 @@ +import torch + +from torchvision import transforms +from torchvision.features import BoundingBox, Image + +image = Image(torch.rand(3, 10, 10)) +image_e = torch.flip(image, (-1,)) + +bbox = BoundingBox(torch.tensor([7, 3, 9, 8]), image_size=image.image_size, format="xyxy") +bbox_e = BoundingBox(torch.tensor([1, 3, 3, 8]), image_size=image_e.image_size, format="xyxy") + +transform = transforms.HorizontalFlip() +functional_transform = transforms.HorizontalFlip.apply + +torch.testing.assert_close(transform(image), image_e) +torch.testing.assert_close(functional_transform(image), image_e) + +torch.testing.assert_close(transform(bbox).convert("xyxy"), bbox_e) +torch.testing.assert_close(functional_transform(bbox).convert("xyxy"), bbox_e) + +sample_a = transform(dict(image=image, bbox=bbox)) +torch.testing.assert_close(sample_a["image"], image_e) +torch.testing.assert_close(sample_a["bbox"].convert("xyxy"), bbox_e) + +image_a, bbox_a = transform(image, bbox) +torch.testing.assert_close(image_a, image_e) +torch.testing.assert_close(bbox_a.convert("xyxy"), bbox_e) + +composed_transform = transforms.Compose(transform, transform) + +image_a, bbox_a = composed_transform(image, bbox) +torch.testing.assert_close(image_a, image) +torch.testing.assert_close(bbox_a.convert("xyxy"), bbox)