[prototype] Speed up adjust_sharpness_image_tensor (#6930)
* Speed up `adjust_sharpness_image_tensor`

* Add a comment
datumbox authored Nov 8, 2022
1 parent bf58902 commit 7a7ab7e
Showing 1 changed file with 27 additions and 1 deletion.
torchvision/prototype/transforms/functional/_color.py
@@ -1,4 +1,5 @@
 import torch
+from torch.nn.functional import conv2d
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     if image.numel() == 0 or height <= 2 or width <= 2:
         return image
 
+    bound = _FT._max_value(image.dtype)
+    fp = image.is_floating_point()
     shape = image.shape
 
     if image.ndim > 4:
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     else:
         needs_unsquash = False
 
-    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
+    # The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle.
+    kernel_dtype = image.dtype if fp else torch.float32
+    a, b = 1.0 / 13.0, 5.0 / 13.0
+    kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device)
+    kernel = kernel.expand(num_channels, 1, 3, 3)
+
+    # We copy and cast at the same time to avoid modifications on the original data
+    output = image.to(dtype=kernel_dtype, copy=True)
+    blurred_degenerate = conv2d(output, kernel, groups=num_channels)
+    if not fp:
+        # it is better to round before cast
+        blurred_degenerate = blurred_degenerate.round_()
+
+    # Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice.
+    view = output[..., 1:-1, 1:-1]
+
+    # We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent:
+    # x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
+    view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor))
+
+    # The actual data of output has been modified by the above. We only need to clamp and cast now.
+    output = output.clamp_(0, bound)
+    if not fp:
+        output = output.to(image.dtype)

     if needs_unsquash:
         output = output.reshape(shape)
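
To make the kernel comment above concrete, here is a small standalone sketch (not part of the commit; the tensor shapes and variable names are illustrative) showing that the 3x3 filter sums to 1 and that conv2d with groups=num_channels applies it depthwise, so a constant image passes through unchanged apart from the 1-pixel border lost to the unpadded convolution:

import torch
from torch.nn.functional import conv2d

num_channels = 3
a, b = 1.0 / 13.0, 5.0 / 13.0
kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]])
assert torch.isclose(kernel.sum(), torch.tensor(1.0))  # 8 * (1/13) + 5/13 == 1

# groups=num_channels makes this a depthwise convolution: each channel is
# smoothed independently with the same 3x3 filter.
weight = kernel.expand(num_channels, 1, 3, 3)
constant = torch.full((1, num_channels, 8, 8), 0.5)
blurred = conv2d(constant, weight, groups=num_channels)

assert blurred.shape == (1, num_channels, 6, 6)  # no padding, so a 1-pixel border is dropped
assert torch.allclose(blurred, torch.full_like(blurred, 0.5))  # normalized kernel preserves flat regions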
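The in-place blend relies on the identity quoted in the comment: x+(1-r)*(y-x) = x*r + y*(1-r). A quick numerical check of that equivalence (the tensor names below are arbitrary stand-ins, not from the commit):

import torch

r = 1.7                    # sharpness_factor: r > 1 sharpens, r < 1 blurs, r == 1 is a no-op
x = torch.rand(3, 5, 5)    # stands in for the original (sharp) pixels
y = torch.rand(3, 5, 5)    # stands in for the blurred degenerate pixels

naive = x * r + y * (1.0 - r)      # two elementwise multiplies plus an add
fused = x + (1.0 - r) * (y - x)    # one multiply; maps directly onto Tensor.add_(other, alpha=...)
assert torch.allclose(naive, fused, atol=1e-6)

Writing the blend as x + (1-r)*(y-x) is what lets the commit express it as a single in-place view.add_(..., alpha=1 - sharpness_factor) instead of allocating intermediates for both products.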

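Finally, a hedged end-to-end sanity check. Assuming the prototype functional namespace re-exports adjust_sharpness_image_tensor (the import path below is a guess based on the file location, not confirmed by the diff), the rewritten kernel should still agree with the stable adjust_sharpness op on a uint8 image up to rounding:

import torch
import torchvision.transforms.functional as F
# Assumed re-export; the diff only shows torchvision/prototype/transforms/functional/_color.py.
from torchvision.prototype.transforms.functional import adjust_sharpness_image_tensor

img = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)

out_new = adjust_sharpness_image_tensor(img, sharpness_factor=2.0)
out_ref = F.adjust_sharpness(img, sharpness_factor=2.0)

assert out_new.dtype == img.dtype and out_new.shape == img.shape
# The two paths round and cast at slightly different points, so allow an off-by-one per pixel.
assert (out_new.int() - out_ref.int()).abs().max() <= 1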