-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[prototype] Speed up Augment Transform Classes #6835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9440d3c
9b9b75e
7cc8ec9
b9dc49b
fa82d63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,24 +40,22 @@ def __init__( | |
raise ValueError("Scale should be between 0 and 1") | ||
self.scale = scale | ||
self.ratio = ratio | ||
self.value = value | ||
if isinstance(value, (int, float)): | ||
self.value = [value] | ||
elif isinstance(value, str): | ||
self.value = None | ||
elif isinstance(value, tuple): | ||
self.value = list(value) | ||
else: | ||
self.value = value | ||
self.inplace = inplace | ||
|
||
self._log_ratio = torch.log(torch.tensor(self.ratio)) | ||
|
||
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | ||
img_c, img_h, img_w = query_chw(flat_inputs) | ||
|
||
if isinstance(self.value, (int, float)): | ||
value = [self.value] | ||
elif isinstance(self.value, str): | ||
value = None | ||
elif isinstance(self.value, tuple): | ||
value = list(self.value) | ||
else: | ||
value = self.value | ||
|
||
if value is not None and not (len(value) in (1, img_c)): | ||
if self.value is not None and not (len(self.value) in (1, img_c)): | ||
raise ValueError( | ||
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" | ||
) | ||
|
@@ -79,10 +77,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | |
if not (h < img_h and w < img_w): | ||
continue | ||
|
||
if value is None: | ||
if self.value is None: | ||
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() | ||
else: | ||
v = torch.tensor(value)[:, None, None] | ||
v = torch.tensor(self.value)[:, None, None] | ||
|
||
i = torch.randint(0, img_h - h + 1, size=(1,)).item() | ||
j = torch.randint(0, img_w - w + 1, size=(1,)).item() | ||
|
@@ -121,8 +119,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: | |
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: | ||
if inpt.ndim < 2: | ||
raise ValueError("Need a batch of one hot labels") | ||
output = inpt.clone() | ||
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) | ||
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we saw in previous optimizations (for example #6821), removing clone in favour of moving the copy on the first multiplication has speed benefits. Below I benchmark the following two alternatives for various inputs (images and labels) to showcase this: def withclone(inpt, lam=0.5):
output = inpt.clone()
return output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
def withoutclone(inpt, lam=0.5):
return inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) Below we benchmark the 2 above methods alone:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because it took me a while to figure out why the
would modify |
||
return features.OneHotLabel.wrap_like(inpt, output) | ||
|
||
|
||
|
@@ -136,8 +133,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: | |
expected_ndim = 5 if isinstance(inpt, features.Video) else 4 | ||
if inpt.ndim < expected_ndim: | ||
raise ValueError("The transform expects a batched input") | ||
output = inpt.clone() | ||
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) | ||
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) | ||
|
||
if isinstance(inpt, (features.Image, features.Video)): | ||
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] | ||
|
@@ -243,11 +239,12 @@ def _copy_paste( | |
if blending: | ||
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) | ||
|
||
inverse_paste_alpha_mask = paste_alpha_mask.logical_not() | ||
# Copy-paste images: | ||
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) | ||
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask)) | ||
|
||
# Copy-paste masks: | ||
masks = masks * (~paste_alpha_mask) | ||
masks = masks * inverse_paste_alpha_mask | ||
non_all_zero_masks = masks.sum((-1, -2)) > 0 | ||
masks = masks[non_all_zero_masks] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moving
value
handling from run-time to constructor. This is more code-quality than perf.