From 9440d3c14320821fa58ee28b580dcbc4dd6999e2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 25 Oct 2022 12:18:10 +0100 Subject: [PATCH 1/3] Moving value estimation of `RandomErasing` from runtime to constructor. --- torchvision/prototype/transforms/_augment.py | 24 +++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 4a45f9f5788..9d896819721 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -40,7 +40,14 @@ def __init__( raise ValueError("Scale should be between 0 and 1") self.scale = scale self.ratio = ratio - self.value = value + if isinstance(value, (int, float)): + self.value = [value] + elif isinstance(value, str): + self.value = None + elif isinstance(value, tuple): + self.value = list(value) + else: + self.value = value self.inplace = inplace self._log_ratio = torch.log(torch.tensor(self.ratio)) @@ -48,16 +55,7 @@ def __init__( def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: img_c, img_h, img_w = query_chw(flat_inputs) - if isinstance(self.value, (int, float)): - value = [self.value] - elif isinstance(self.value, str): - value = None - elif isinstance(self.value, tuple): - value = list(self.value) - else: - value = self.value - - if value is not None and not (len(value) in (1, img_c)): + if self.value is not None and not (len(self.value) in (1, img_c)): raise ValueError( f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" ) @@ -79,10 +77,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: if not (h < img_h and w < img_w): continue - if value is None: + if self.value is None: v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() else: - v = torch.tensor(value)[:, None, None] + v = torch.tensor(self.value)[:, None, None] i = torch.randint(0, img_h - h + 1, size=(1,)).item() j = torch.randint(0, img_w - w + 1, size=(1,)).item() From 9b9b75e58c8e8853b4a81968c8b27eec7547081b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 25 Oct 2022 13:39:36 +0100 Subject: [PATCH 2/3] Speed up mixing on MixUp/Cutmix and small optimization on SimpleCopyPaste. --- torchvision/prototype/transforms/_augment.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 9d896819721..c6a8099ddf9 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -119,8 +119,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: if inpt.ndim < 2: raise ValueError("Need a batch of one hot labels") - output = inpt.clone() - output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) + output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) return features.OneHotLabel.wrap_like(inpt, output) @@ -134,8 +133,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: expected_ndim = 5 if isinstance(inpt, features.Video) else 4 if inpt.ndim < expected_ndim: raise ValueError("The transform expects a batched input") - output = inpt.clone() - output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) + output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) if isinstance(inpt, (features.Image, features.Video)): output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] @@ -241,11 +239,12 @@ def _copy_paste( if blending: paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) + inverse_paste_alpha_mask = ~paste_alpha_mask # Copy-paste images: - image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) + image = (image * inverse_paste_alpha_mask).add_(paste_image * paste_alpha_mask) # Copy-paste masks: - masks = masks * (~paste_alpha_mask) + masks = masks * inverse_paste_alpha_mask non_all_zero_masks = masks.sum((-1, -2)) > 0 masks = masks[non_all_zero_masks] From b9dc49b9fdf0e493d7cacb47aad3fab8ac397c18 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 10:24:43 +0100 Subject: [PATCH 3/3] Apply nits. --- torchvision/prototype/transforms/_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index c6a8099ddf9..b4834e47f88 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -239,9 +239,9 @@ def _copy_paste( if blending: paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) - inverse_paste_alpha_mask = ~paste_alpha_mask + inverse_paste_alpha_mask = paste_alpha_mask.logical_not() # Copy-paste images: - image = (image * inverse_paste_alpha_mask).add_(paste_image * paste_alpha_mask) + image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask)) # Copy-paste masks: masks = masks * inverse_paste_alpha_mask