35 changes: 16 additions & 19 deletions torchvision/prototype/transforms/_augment.py
@@ -40,24 +40,22 @@ def __init__(
raise ValueError("Scale should be between 0 and 1")
self.scale = scale
self.ratio = ratio
self.value = value
if isinstance(value, (int, float)):
self.value = [value]
elif isinstance(value, str):
self.value = None
elif isinstance(value, tuple):
self.value = list(value)
else:
self.value = value
Comment on lines +43 to +50
Contributor Author:

Moving the value handling from run-time to the constructor. This is more of a code-quality improvement than a perf one.

self.inplace = inplace

self._log_ratio = torch.log(torch.tensor(self.ratio))

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)

if isinstance(self.value, (int, float)):
value = [self.value]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value

if value is not None and not (len(value) in (1, img_c)):
if self.value is not None and not (len(self.value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)
@@ -79,10 +77,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
if not (h < img_h and w < img_w):
continue

if value is None:
if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
v = torch.tensor(self.value)[:, None, None]

i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
@@ -121,8 +119,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.clone()
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
Contributor Author (@datumbox, Oct 25, 2022):

As we saw in previous optimizations (for example #6821), removing the clone in favour of moving the copy to the first multiplication has speed benefits. Below I benchmark the following two alternatives for various inputs (images and labels) to showcase this:

def withclone(inpt, lam=0.5):
    output = inpt.clone()
    return output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))


def withoutclone(inpt, lam=0.5):
    return inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

Below we benchmark the two methods above in isolation:

[-------------- Mixing cpu torch.float32 --------------]
                         |  withclone   |  withoutclone 
1 threads: ---------------------------------------------
      (16, 3, 400, 400)  |    17200     |      15100    
      (16, 1000)         |       22     |         21    
6 threads: ---------------------------------------------
      (16, 3, 400, 400)  |    17700     |      15400    
      (16, 1000)         |       23     |         20    

Times are in microseconds (us).

[------------- Mixing cuda torch.float32 --------------]
                         |  withclone   |  withoutclone 
1 threads: ---------------------------------------------
      (16, 3, 400, 400)  |    275.3     |       226     
      (16, 1000)         |     32.8     |        26     
6 threads: ---------------------------------------------
      (16, 3, 400, 400)  |    275.9     |       226     
      (16, 1000)         |     32.6     |        26     

Times are in microseconds (us).
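
The benchmark script itself isn't shown here; a minimal sketch with torch.utils.benchmark that could reproduce the CPU comparison (shapes and thread counts taken from the tables above, everything else is an assumption) would look roughly like this:

import torch
import torch.utils.benchmark as benchmark

def withclone(inpt, lam=0.5):
    output = inpt.clone()
    return output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))

def withoutclone(inpt, lam=0.5):
    return inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

results = []
for shape in [(16, 3, 400, 400), (16, 1000)]:
    x = torch.rand(shape, dtype=torch.float32)
    for num_threads in (1, 6):
        for fn in (withclone, withoutclone):
            results.append(
                benchmark.Timer(
                    stmt="fn(x)",
                    globals={"fn": fn, "x": x},
                    num_threads=num_threads,
                    label="Mixing cpu torch.float32",
                    sub_label=str(shape),
                    description=fn.__name__,
                ).blocked_autorange(min_run_time=1)
            )

# Prints a table comparing the two variants per shape and thread count.
benchmark.Compare(results).print()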

Contributor:

Because it took me a while to figure out why the clone is needed in the first place: without it,

inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul_(lam))

would modify inpt in place due to inpt.mul_(lam). This is the inplace operation @datumbox is talking about above. inpt.roll always returns a copy, so we can use inplace operations afterwards.
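
A quick illustrative check of that behaviour (not part of the diff):

import torch

x = torch.ones(4)
y = x.roll(1, 0)   # roll materializes a new tensor
y.mul_(2.0)        # in-place op on the copy ...
print(x)           # ... x is untouched: tensor([1., 1., 1., 1.])
x.mul_(2.0)        # mul_ on x itself does mutate it
print(x)           # tensor([2., 2., 2., 2.])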

return features.OneHotLabel.wrap_like(inpt, output)


@@ -136,8 +133,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
output = inpt.clone()
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

if isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
@@ -243,11 +239,12 @@ def _copy_paste(
if blending:
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))

# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
masks = masks * inverse_paste_alpha_mask
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]

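In the copy-paste hunk, the inverted mask is now computed once with logical_not and reused for both the image and the masks, and the addition is folded in-place into the first product. A small equivalence check of the image compositing rewrite (illustrative only, shapes made up):

import torch

image = torch.rand(3, 8, 8)
paste_image = torch.rand(3, 8, 8)
paste_alpha_mask = torch.rand(8, 8) > 0.5

inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
old = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
new = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
torch.testing.assert_close(old, new)  # same result, one fewer intermediate and no repeated mask inversion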