diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 3826293f3ed..364b26fa995 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -26,7 +26,8 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
 from torchvision.prototype.transforms.utils import check_type
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
+
+from torchvision.transforms.functional import get_image_size, InterpolationMode, pil_to_tensor, to_pil_image
 
 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
 
@@ -1924,3 +1925,59 @@ def test__transform(self, inpt):
         assert type(output) is type(inpt)
         assert output.shape[-4] == num_samples
         assert output.dtype == inpt.dtype
+
+
+class TestMixupDetection:
+    def create_fake_image(self, mocker, image_type, *, size=(32, 32), color=123):
+        if image_type == PIL.Image.Image:
+            return PIL.Image.new("RGB", size, color)
+        return mocker.MagicMock(spec=image_type)
+
+    @pytest.mark.parametrize("ratio", [0.0, 0.5, 1.0])
+    def test__mixup(self, mocker, ratio):
+        image_1 = self.create_fake_image(mocker, PIL.Image.Image, size=(128, 128), color=(124, 124, 124))
+        image_1 = pil_to_tensor(image_1)
+        target_1 = {
+            "boxes": datapoints.BoundingBox(
+                torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]),
+                format="XYXY",
+                # `get_image_size` returns `[width, height]`, while `spatial_size` is `(height, width)`
+                spatial_size=tuple(get_image_size(image_1)[::-1]),
+            ),
+            "labels": datapoints.Label(torch.tensor([1, 2])),
+        }
+        sample_1 = {
+            "image": image_1,
+            "boxes": target_1["boxes"],
+            "labels": target_1["labels"],
+        }
+
+        image_2 = self.create_fake_image(mocker, PIL.Image.Image, size=(128, 128), color=(0, 0, 0))
+        image_2 = pil_to_tensor(image_2)
+        target_2 = {
+            "boxes": datapoints.BoundingBox(
+                torch.tensor([[10.0, 0.0, 20.0, 20.0], [10.0, 20.0, 30.0, 30.0]]),
+                format="XYXY",
+                spatial_size=tuple(get_image_size(image_2)[::-1]),
+            ),
+            "labels": datapoints.Label(torch.tensor([2, 3])),
+        }
+        sample_2 = {
+            "image": image_2,
+            "boxes": target_2["boxes"],
+            "labels": target_2["labels"],
+        }
+
+        transform = transforms.MixupDetection()
+        # `ratio` is keyword-only in `_mixup`
+        output = transform._mixup(sample_1, sample_2, ratio=ratio)
+
+        if ratio == 0:
+            assert output is sample_2
+        elif ratio == 1:
+            assert output is sample_1
+        elif ratio == 0.5:
+            # `==` on multi-element tensors is ambiguous inside `assert`, so we compare with
+            # `torch.testing.assert_close` on plain tensors instead
+            torch.testing.assert_close(output["image"], ((image_1 + image_2) / 2).to(image_1.dtype))
+            torch.testing.assert_close(
+                output["boxes"].as_subclass(torch.Tensor), torch.cat([target_1["boxes"], target_2["boxes"]])
+            )
+            torch.testing.assert_close(
+                output["labels"].as_subclass(torch.Tensor), torch.cat([target_1["labels"], target_2["labels"]])
+            )
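+
+    def test__mixup_pads_to_max_spatial_size(self, mocker):
+        # A hypothetical extra check, not part of the parametrized test above: for inputs of
+        # different spatial sizes, `_mixup` should pad to the element-wise maximum size and
+        # concatenate the targets of both samples.
+        image_1 = pil_to_tensor(self.create_fake_image(mocker, PIL.Image.Image, size=(32, 32)))
+        image_2 = pil_to_tensor(self.create_fake_image(mocker, PIL.Image.Image, size=(48, 64)))
+        sample_1 = {
+            "image": image_1,
+            "boxes": datapoints.BoundingBox(
+                torch.tensor([[0.0, 0.0, 10.0, 10.0]]), format="XYXY", spatial_size=(32, 32)
+            ),
+            "labels": datapoints.Label(torch.tensor([1])),
+        }
+        sample_2 = {
+            "image": image_2,
+            "boxes": datapoints.BoundingBox(
+                torch.tensor([[10.0, 10.0, 20.0, 20.0]]), format="XYXY", spatial_size=(64, 48)
+            ),
+            "labels": datapoints.Label(torch.tensor([2])),
+        }
+
+        output = transforms.MixupDetection()._mixup(sample_1, sample_2, ratio=0.5)
+
+        # PIL sizes are (width, height), so image_2 is 64x48 in (height, width) terms
+        assert output["image"].shape[-2:] == (64, 48)
+        assert output["boxes"].spatial_size == (64, 48)
+        assert len(output["labels"]) == 2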
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 04b007190b8..13eb216813e 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -5,7 +5,7 @@
 from ._transform import Transform  # usort: skip
 from ._presets import StereoMatching  # usort: skip
 
-from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
+from ._augment import MixupDetection, RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 3160770a09d..0ff48c0c3b1 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -1,20 +1,22 @@
 import math
 import numbers
 import warnings
-from typing import Any, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 
 import PIL.Image
-import torch
-from torch.utils._pytree import tree_flatten, tree_unflatten
+import torch
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
+from torchvision.prototype.transforms import functional as F, InterpolationMode
 
-from ._transform import _RandomApplyTransform
+from ._transform import _DetectionBatchTransform, _RandomApplyTransform
 from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
 
+D = TypeVar("D", bound=datapoints._datapoint.Datapoint)
+
+
 class RandomErasing(_RandomApplyTransform):
     _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
 
@@ -191,7 +193,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
 
-class SimpleCopyPaste(Transform):
+class SimpleCopyPaste(_DetectionBatchTransform):
     def __init__(
         self,
         blending: bool = True,
@@ -203,184 +205,188 @@ def __init__(
         self.blending = blending
         self.antialias = antialias
 
-    def _copy_paste(
-        self,
-        image: datapoints.TensorImageType,
-        target: Dict[str, Any],
-        paste_image: datapoints.TensorImageType,
-        paste_target: Dict[str, Any],
-        random_selection: torch.Tensor,
-        blending: bool,
-        resize_interpolation: F.InterpolationMode,
-        antialias: Optional[bool],
-    ) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]:
-
-        paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
-        paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
-        paste_labels = paste_target["labels"].wrap_like(
-            paste_target["labels"], paste_target["labels"][random_selection]
+    def forward(self, *inputs: Any) -> Any:
+        flat_batch_with_spec, batch = self._flatten_and_extract_data(
+            inputs,
+            image=(datapoints.Image, PIL.Image.Image, is_simple_tensor),
+            boxes=(datapoints.BoundingBox,),
+            masks=(datapoints.Mask,),
+            labels=(datapoints.Label, datapoints.OneHotLabel),
         )
+        batch = self._to_image_tensor(batch)
 
-        masks = target["masks"]
-
-        # We resize source and paste data if they have different sizes
-        # This is something different to TF implementation we introduced here as
-        # originally the algorithm works on equal-sized data
-        # (for example, coming from LSJ data augmentations)
-        size1 = cast(List[int], image.shape[-2:])
-        size2 = paste_image.shape[-2:]
-        if size1 != size2:
-            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
-            paste_masks = F.resize(paste_masks, size=size1)
-            paste_boxes = F.resize(paste_boxes, size=size1)
-
-        paste_alpha_mask = paste_masks.sum(dim=0) > 0
-
-        if blending:
-            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
+        batch_output = []
+        for sample, sample_rolled in zip(batch, batch[-1:] + batch[:-1]):
+            num_masks = len(sample_rolled["masks"])
+            if num_masks < 1:
+                # This might for example happen with the LSJ augmentation strategy
+                batch_output.append(sample)
+                continue
 
-        inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
-        # Copy-paste images:
-        image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
+            random_selection = torch.randint(0, num_masks, (num_masks,), device=sample_rolled["masks"].device)
+            random_selection = torch.unique(random_selection)
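+            # `randint` draws `num_masks` indices with replacement and `unique` deduplicates and
+            # sorts them, so a random, non-empty subset of the paste masks is selected. For
+            # example, with `num_masks=4` a draw of [2, 0, 2, 1] selects the masks [0, 1, 2];
+            # on average roughly 63% of the masks are picked.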
 
-        # Copy-paste masks:
-        masks = masks * inverse_paste_alpha_mask
-        non_all_zero_masks = masks.sum((-1, -2)) > 0
-        masks = masks[non_all_zero_masks]
+            batch_output.append(
+                self._simple_copy_paste(
+                    sample,
+                    sample_rolled,
+                    random_selection=random_selection,
+                    blending=self.blending,
+                    resize_interpolation=self.resize_interpolation,
+                    antialias=self.antialias,
+                )
+            )
 
-        # Do a shallow copy of the target dict
-        out_target = {k: v for k, v in target.items()}
+        return self._unflatten_and_insert_data(flat_batch_with_spec, batch_output)
 
-        out_target["masks"] = torch.cat([masks, paste_masks])
+    @staticmethod
+    def _wrapping_getitem(datapoint: D, index: Any) -> D:
+        return type(datapoint).wrap_like(datapoint, datapoint[index])
 
-        # Copy-paste boxes and labels
-        bbox_format = target["boxes"].format
-        xyxy_boxes = masks_to_boxes(masks)
-        # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
-        # we need to add +1 to x2y2.
-        # There is a similar +1 in other reference implementations:
+    def _simple_copy_paste(
+        self,
+        sample_1: Dict[str, Any],
+        sample_2: Dict[str, Any],
+        *,
+        random_selection: torch.Tensor,
+        blending: bool,
+        resize_interpolation: F.InterpolationMode,
+        antialias: Optional[bool],
+    ) -> Dict[str, Any]:
+        dst_image = sample_1["image"]
+        dst_masks = sample_1["masks"]
+        dst_labels = sample_1["labels"]
+
+        src_image = sample_2["image"]
+        src_masks = self._wrapping_getitem(sample_2["masks"], random_selection)
+        src_boxes = self._wrapping_getitem(sample_2["boxes"], random_selection)
+        src_labels = self._wrapping_getitem(sample_2["labels"], random_selection)
+
+        # In case the `dst_image` and `src_image` have different spatial sizes, we resize `src_image` and the
+        # corresponding annotations to `dst_image`'s spatial size. This differs from the official implementation,
+        # which only works with equally sized data, e.g. coming from the LSJ augmentation strategy.
+        dst_spatial_size = dst_image.shape[-2:]
+        src_spatial_size = src_image.shape[-2:]
+        if dst_spatial_size != src_spatial_size:
+            src_image = F.resize(
+                src_image, size=dst_spatial_size, interpolation=resize_interpolation, antialias=antialias
+            )
+            src_masks = F.resize(src_masks, size=dst_spatial_size)
+            src_boxes = F.resize(src_boxes, size=dst_spatial_size)
+
+        src_paste_mask = src_masks.sum(dim=0, keepdim=True) > 0
+        # Although the parameter is called "blending", we don't actually blend here. `src_paste_mask` is a boolean
+        # mask and although `F.gaussian_blur` internally converts to floating point, it will be converted back to
+        # boolean on the way out. Meaning, although we blur, `src_paste_mask` will have no values other than 0 or 1.
+        # The original paper doesn't specify how blending should be done and the official implementation is not
+        # helpful either:
+        # https://github.com/tensorflow/tpu/blob/732902a457b2a8924f885ee832830e1bf6d7c537/models/official/detection/dataloader/maskrcnn_parser_with_copy_paste.py#L331-L334
+        if blending:
+            src_paste_mask = F.gaussian_blur(src_paste_mask, kernel_size=[5, 5], sigma=[2.0])
+        dst_paste_mask = src_paste_mask.logical_not()
+
+        image = datapoints.Image.wrap_like(
+            dst_image, dst_image.mul(dst_paste_mask).add_(src_image.mul(src_paste_mask))
+        )
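+        # The compositing above works per pixel: `out = dst * ~mask + src * mask`, i.e. every
+        # pixel covered by a selected paste mask is taken from `src_image` and every other
+        # pixel stays `dst_image`. With a boolean mask, `mul` acts as a select, so even with
+        # `blending=True` no pixel values are actually mixed.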
+
+        dst_masks = dst_masks * dst_paste_mask
+        # Since we paste the `src_image` into the `dst_image`, we might completely cover an object previously
+        # visible in `dst_image`. Furthermore, with `blending=True`, regions that were small to begin with might
+        # shrink enough to vanish. Thus, we check for degenerate masks and remove them.
+        valid_dst_masks = dst_masks.sum((-1, -2)) > 0
+        dst_masks = dst_masks[valid_dst_masks]
+        masks = datapoints.Mask.wrap_like(dst_masks, torch.cat([dst_masks, src_masks]))
+
+        # Since the `dst_masks` might have changed above, we recompute the corresponding `dst_boxes`.
+        dst_boxes_xyxy = masks_to_boxes(dst_masks)
+        # `masks_to_boxes` produces boxes with x2y2 inclusive, but x2y2 should be exclusive. Thus, we increase by
+        # one. There is a similar behavior in other reference implementations:
         # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
-        xyxy_boxes[:, 2:] += 1
-        boxes = F.convert_format_bounding_box(
-            xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
+        dst_boxes_xyxy[:, 2:] += 1
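+        # For example, a mask covering only the pixel (x=3, y=5) yields the inclusive box
+        # [3, 5, 3, 5]; adding one to x2y2 turns it into the exclusive box [3, 5, 4, 6]
+        # with a width and height of 1, matching the XYXY convention used elsewhere.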
+        dst_boxes = F.convert_format_bounding_box(
+            dst_boxes_xyxy, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=src_boxes.format, inplace=True
         )
-        out_target["boxes"] = torch.cat([boxes, paste_boxes])
+        dst_boxes = datapoints.BoundingBox(dst_boxes, format=src_boxes.format, spatial_size=dst_spatial_size)
+        boxes = datapoints.BoundingBox.wrap_like(dst_boxes, torch.cat([dst_boxes, src_boxes]))
 
-        labels = target["labels"][non_all_zero_masks]
-        out_target["labels"] = torch.cat([labels, paste_labels])
+        labels = datapoints.Label.wrap_like(dst_labels, torch.cat([dst_labels[valid_dst_masks], src_labels]))
 
         # Check for degenerated boxes and remove them
-        boxes = F.convert_format_bounding_box(
-            out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
+        # FIXME: This can only happen for the `src_boxes`, right? Since `dst_boxes` were re-computed from
+        # `dst_masks` above, they should all be valid. If so, degenerate boxes at this stage should only come from
+        # the resizing of `src_boxes` above. Maybe we can remove them already at that stage?
+        # TODO: Maybe unify this with `transforms.RemoveSmallBoundingBoxes()`?
+        boxes_xyxy = F.convert_format_bounding_box(
+            boxes, old_format=boxes.format, new_format=datapoints.BoundingBoxFormat.XYXY
         )
-        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+        degenerate_boxes = boxes_xyxy[:, 2:].le(boxes_xyxy[:, :2])
 
         if degenerate_boxes.any():
-            valid_targets = ~degenerate_boxes.any(dim=1)
-
-            out_target["boxes"] = boxes[valid_targets]
-            out_target["masks"] = out_target["masks"][valid_targets]
-            out_target["labels"] = out_target["labels"][valid_targets]
-
-        return image, out_target
-
-    def _extract_image_targets(
-        self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]:
-        # fetch all images, bboxes, masks and labels from unstructured input
-        # with List[image], List[BoundingBox], List[Mask], List[Label]
-        images, bboxes, masks, labels = [], [], [], []
-        for obj in flat_sample:
-            if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
-                images.append(obj)
-            elif isinstance(obj, PIL.Image.Image):
-                images.append(F.to_image_tensor(obj))
-            elif isinstance(obj, datapoints.BoundingBox):
-                bboxes.append(obj)
-            elif isinstance(obj, datapoints.Mask):
-                masks.append(obj)
-            elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
-                labels.append(obj)
-
-        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
-            raise TypeError(
-                f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
-                "BoundingBoxes, Masks and Labels or OneHotLabels."
-            )
+            valid_boxes = ~degenerate_boxes.any(dim=-1)
+
+            masks = self._wrapping_getitem(masks, valid_boxes)
+            boxes = self._wrapping_getitem(boxes, valid_boxes)
+            labels = self._wrapping_getitem(labels, valid_boxes)
 
-        targets = []
-        for bbox, mask, label in zip(bboxes, masks, labels):
-            targets.append({"boxes": bbox, "masks": mask, "labels": label})
+        return dict(image=image, masks=masks, boxes=boxes, labels=labels)
 
-        return images, targets
 
-    def _insert_outputs(
+class MixupDetection(_DetectionBatchTransform):
+    def __init__(
         self,
-        flat_sample: List[Any],
-        output_images: List[datapoints.TensorImageType],
-        output_targets: List[Dict[str, Any]],
+        *,
+        alpha: float = 1.5,
     ) -> None:
-        c0, c1, c2, c3 = 0, 0, 0, 0
-        for i, obj in enumerate(flat_sample):
-            if isinstance(obj, datapoints.Image):
-                flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0])
-                c0 += 1
-            elif isinstance(obj, PIL.Image.Image):
-                flat_sample[i] = F.to_image_pil(output_images[c0])
-                c0 += 1
-            elif is_simple_tensor(obj):
-                flat_sample[i] = output_images[c0]
-                c0 += 1
-            elif isinstance(obj, datapoints.BoundingBox):
-                flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
-                c1 += 1
-            elif isinstance(obj, datapoints.Mask):
-                flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
-                c2 += 1
-            elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
-                flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
-                c3 += 1
-
-    def forward(self, *inputs: Any) -> Any:
-        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
-
-        images, targets = self._extract_image_targets(flat_inputs)
-
-        # images = [t1, t2, ..., tN]
-        # Let's define paste_images as shifted list of input images
-        # paste_images = [t2, t3, ..., tN, t1]
-        # FYI: in TF they mix data on the dataset level
-        images_rolled = images[-1:] + images[:-1]
-        targets_rolled = targets[-1:] + targets[:-1]
-
-        output_images, output_targets = [], []
-
-        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+        super().__init__()
+        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
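+        # Beta(alpha, alpha) is symmetric around 0.5. For the default alpha=1.5 the density
+        # vanishes at 0 and 1, so ratios that would degenerate into returning one of the two
+        # samples unchanged (see `_mixup`) are rare, and moderate mixing ratios dominate.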
 
-            # Random paste targets selection:
-            num_masks = len(paste_target["masks"])
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
+        if has_any(flat_inputs, datapoints.Mask, datapoints.Video):
+            raise TypeError(f"{type(self).__name__}() is only supported for images and bounding boxes.")
 
-            if num_masks < 1:
-                # Such degerante case with num_masks=0 can happen with LSJ
-                # Let's just return (image, target)
-                output_image, output_target = image, target
-            else:
-                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
-                random_selection = torch.unique(random_selection)
-
-                output_image, output_target = self._copy_paste(
-                    image,
-                    target,
-                    paste_image,
-                    paste_target,
-                    random_selection=random_selection,
-                    blending=self.blending,
-                    resize_interpolation=self.resize_interpolation,
-                    antialias=self.antialias,
-                )
-            output_images.append(output_image)
-            output_targets.append(output_target)
+    def forward(self, *inputs: Any) -> Any:
+        flat_batch_with_spec, batch = self._flatten_and_extract_data(
+            inputs,
+            image=(datapoints.Image, PIL.Image.Image, is_simple_tensor),
+            boxes=(datapoints.BoundingBox,),
+            labels=(datapoints.Label, datapoints.OneHotLabel),
+        )
+        self._check_inputs(flat_batch_with_spec[0])
+
+        batch = self._to_image_tensor(batch)
+
+        batch_output = [
+            self._mixup(sample, sample_rolled, ratio=float(self._dist.sample()))
+            for sample, sample_rolled in zip(batch, batch[-1:] + batch[:-1])
+        ]
+
+        return self._unflatten_and_insert_data(flat_batch_with_spec, batch_output)
+
+    def _mixup(self, sample_1: Dict[str, Any], sample_2: Dict[str, Any], *, ratio: float) -> Dict[str, Any]:
+        if ratio >= 1.0:
+            return sample_1
+        elif ratio == 0.0:
+            return sample_2
+
+        h_1, w_1 = sample_1["image"].shape[-2:]
+        h_2, w_2 = sample_2["image"].shape[-2:]
+        h_mixup = max(h_1, h_2)
+        w_mixup = max(w_1, w_2)
+
+        # TODO: add the option to fill this with something other than 0
+        dtype = sample_1["image"].dtype if sample_1["image"].is_floating_point() else torch.float32
+        mix_image = F.pad_image_tensor(
+            sample_1["image"].to(dtype), padding=[0, 0, w_mixup - w_1, h_mixup - h_1], fill=None
+        ).mul_(ratio)
+        mix_image[..., :h_2, :w_2] += sample_2["image"] * (1.0 - ratio)
+        mix_image = mix_image.to(sample_1["image"])
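+        # For example, mixing a 480x640 with a 512x512 image yields a 512x640 canvas:
+        # `sample_1`'s image is padded at the right/bottom (`padding` is ordered
+        # [left, top, right, bottom]), `sample_2`'s image lands in the top-left corner,
+        # and the box coordinates of both samples stay valid without shifting.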
+
+        mix_boxes = datapoints.BoundingBox.wrap_like(
+            sample_1["boxes"],
+            torch.cat([sample_1["boxes"], sample_2["boxes"]], dim=-2),
+            spatial_size=(h_mixup, w_mixup),
+        )
 
-        # Insert updated images and targets into input flat_sample
-        self._insert_outputs(flat_inputs, output_images, output_targets)
+        mix_labels = datapoints.Label.wrap_like(
+            sample_1["labels"],
+            torch.cat([sample_1["labels"], sample_2["labels"]], dim=-1),
+        )
 
-        return tree_unflatten(flat_inputs, spec)
+        return dict(image=mix_image, boxes=mix_boxes, labels=mix_labels)
diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py
index 43224cabd38..e16adb405a1 100644
--- a/torchvision/prototype/transforms/_transform.py
+++ b/torchvision/prototype/transforms/_transform.py
@@ -2,10 +2,13 @@
 from typing import Any, Callable, Dict, List, Tuple, Type, Union
 
 import PIL.Image
+
 import torch
 from torch import nn
-from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision.prototype.transforms.utils import check_type
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+from torchvision.prototype import datapoints
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
 from torchvision.utils import _log_api_usage_once
 
 
@@ -83,3 +86,85 @@ def forward(self, *inputs: Any) -> Any:
         ]
 
         return tree_unflatten(flat_outputs, spec)
+
+
+class _DetectionBatchTransform(Transform):
+    @staticmethod
+    def _flatten_and_extract_data(
+        inputs: Any, **types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]
+    ) -> Tuple[Tuple[List[Any], TreeSpec, List[Dict[str, int]]], List[Dict[str, Any]]]:
+        batch = inputs if len(inputs) > 1 else inputs[0]
+        flat_batch = []
+        sample_specs = []
+
+        offset = 0
+        batch_idcs = []
+        batch_data = []
+        for sample_idx, sample in enumerate(batch):
+            flat_sample, sample_spec = tree_flatten(sample)
+            flat_batch.extend(flat_sample)
+            sample_specs.append(sample_spec)
+
+            sample_types_or_checks = types_or_checks.copy()
+            sample_idcs = {}
+            sample_data = {}
+            for flat_idx, item in enumerate(flat_sample, offset):
+                if not sample_types_or_checks:
+                    break
+
+                for key, types_or_checks_ in sample_types_or_checks.items():
+                    if check_type(item, types_or_checks_):
+                        break
+                else:
+                    continue
+
+                del sample_types_or_checks[key]
+                sample_idcs[key] = flat_idx
+                sample_data[key] = item
+
+            if sample_types_or_checks:
+                raise TypeError(
+                    f"Sample at index {sample_idx} in the batch is missing data for the keys "
+                    f"{sorted(sample_types_or_checks)}"
+                )
+
+            batch_idcs.append(sample_idcs)
+            batch_data.append(sample_data)
+            offset += len(flat_sample)
+
+        batch_spec = TreeSpec(list, context=None, children_specs=sample_specs)
+
+        return (flat_batch, batch_spec, batch_idcs), batch_data
+
+    @staticmethod
+    def _to_image_tensor(batch: List[Dict[str, Any]], *, key: str = "image") -> List[Dict[str, Any]]:
+        for sample in batch:
+            image = sample.pop(key)
+            if isinstance(image, PIL.Image.Image):
+                image = F.pil_to_tensor(image)
+            elif isinstance(image, datapoints.Image):
+                image = image.as_subclass(torch.Tensor)
+            sample[key] = image
+        return batch
+
+    @staticmethod
+    def _unflatten_and_insert_data(
+        flat_batch_with_spec: Tuple[List[Any], TreeSpec, List[Dict[str, int]]],
+        batch: List[Dict[str, Any]],
+    ) -> Any:
+        flat_batch, batch_spec, batch_idcs = flat_batch_with_spec
+
+        for sample_idx, sample_idcs in enumerate(batch_idcs):
+            for key, flat_idx in sample_idcs.items():
+                inpt = flat_batch[flat_idx]
+                item = batch[sample_idx][key]
+
+                if not is_simple_tensor(inpt) and is_simple_tensor(item):
+                    if isinstance(inpt, datapoints._datapoint.Datapoint):
+                        item = type(inpt).wrap_like(inpt, item)
+                    elif isinstance(inpt, PIL.Image.Image):
+                        item = F.to_image_pil(item)
+
+                flat_batch[flat_idx] = item
+
+        return tree_unflatten(flat_batch, batch_spec)
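
A minimal usage sketch of the new `MixupDetection` transform, assuming samples are (image, target) pairs built from the prototype datapoints used throughout this diff:

    import torch
    from torchvision.prototype import datapoints, transforms

    def make_sample(fill):
        image = datapoints.Image(torch.full((3, 32, 32), fill, dtype=torch.uint8))
        target = {
            "boxes": datapoints.BoundingBox(
                torch.tensor([[0.0, 0.0, 10.0, 10.0]]), format="XYXY", spatial_size=(32, 32)
            ),
            "labels": datapoints.Label(torch.tensor([1])),
        }
        return image, target

    batch = [make_sample(0), make_sample(255)]
    # each sample is mixed with its predecessor in the batch; the structure of the
    # batch is preserved by _flatten_and_extract_data/_unflatten_and_insert_data
    output = transforms.MixupDetection()(batch)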