torchvision/transforms/functional_tensor.py: 19 additions & 1 deletion
@@ -262,9 +262,27 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
 
 
 def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
+    if ratio == 1.0:
+        return img1
     ratio = float(ratio)
     bound = 1.0 if img1.is_floating_point() else 255.0
-    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
+
+    if img2.is_floating_point():
+        # Since img2 is float, we can do in-place ops on it. It's a throw-away tensor.
+        # Our strategy is to convert img1 to float and copy it to avoid in-place modifications,
+        # update img2 in-place and add it on the result with an in-place op.
+        result = img1 * ratio
+        img2.mul_(1.0 - ratio)
@vfdev-5 (Contributor) commented on Oct 13, 2022:

I wonder if it is safe to update one of the inputs in-place, img2.mul_(1.0 - ratio)?
Maybe we could do the same trick as in the else branch:

result = (img1 * ratio).div_(1.0 - ratio)
result.add_(img2).mul_(1.0 - ratio)

Another point: if ratio is 1.0 we may get NaN for no reason...
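
For illustration, a minimal sketch (not part of the original thread) of the failure mode: with ratio == 1.0, the division-based rewrite divides by 1.0 - ratio == 0.0, which produces inf, and the final multiplication by 0.0 then turns inf into NaN.

import torch

img1 = torch.rand(3, 4, 4)
img2 = torch.rand(3, 4, 4)
ratio = 1.0

# Suggested rewrite with ratio == 1.0: dividing by 0.0 yields inf ...
result = (img1 * ratio).div_(1.0 - ratio)
# ... and inf * 0.0 yields nan.
result.add_(img2).mul_(1.0 - ratio)
print(torch.isnan(result).any())  # tensor(True)

The early return for ratio == 1.0 in the PR sidesteps this entirely.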

@datumbox (Contributor, Author) commented on Oct 13, 2022:

I fixed the NaN issue in the latest commit.

Your proposal should be slower because it performs an extra division over the entire tensor. Happy to investigate more alternatives. BTW, here is an alternative implementation that is faster than v1 but slower than this PR:

def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    if ratio == 1.0:
        return img1
    ratio = float(ratio)
    bound = 1.0 if img1.is_floating_point() else 255.0

    result = img1 * (ratio / (1.0 - ratio))
    result.add_(img2).mul_(1.0 - ratio).clamp_(0, bound)

    return result.to(img1.dtype)

Benchmarks:

adjust_brightness - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.0017525686789304018, v2: 0.0015204967744648456 - Diff: -13.24%
adjust_brightness - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0008267542696557939, v2: 0.0005970064329449087 - Diff: -27.79%
adjust_brightness - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 5.1363299996592105e-05, v2: 4.837187579832971e-05 - Diff: -5.82%
adjust_brightness - Winner: v2(device=cuda, dtype=torch.float32) - v1: 4.312360244803131e-05, v2: 4.006389770656824e-05 - Diff: -7.10%
adjust_contrast - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.001936913060490042, v2: 0.0018658265587873756 - Diff: -3.67%
adjust_contrast - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0009221774595789611, v2: 0.0008399060717783868 - Diff: -8.92%
adjust_contrast - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.0001281447554938495, v2: 0.00012599090137518942 - Diff: -1.68%
adjust_contrast - Winner: v2(device=cuda, dtype=torch.float32) - v1: 0.00010212179413065314, v2: 9.956091595813633e-05 - Diff: -2.51%
adjust_saturation - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.001968699339777231, v2: 0.001849049050360918 - Diff: -6.08%
adjust_saturation - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0009527031611651182, v2: 0.0008303909911774099 - Diff: -12.84%
adjust_saturation - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.00010614430147688837, v2: 0.00010406337701715528 - Diff: -1.96%
adjust_saturation - Winner: v2(device=cuda, dtype=torch.float32) - v1: 8.796596352476627e-05, v2: 8.61313920468092e-05 - Diff: -2.09%
adjust_sharpness - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.007410917698871344, v2: 0.00669603182002902 - Diff: -9.65%
adjust_sharpness - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.004427262460812926, v2: 0.003999999698717147 - Diff: -9.65%
adjust_sharpness - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.00018968257401138545, v2: 0.00018811142595950514 - Diff: -0.83%
adjust_sharpness - Winner: v2(device=cuda, dtype=torch.float32) - v1: 0.00015139597153756768, v2: 0.00014969582995399833 - Diff: -1.12%

+        result.add_(img2)
+    else:
+        # Since img2 is not float, we can't do in-place ops on it.
+        # To minimize copies/adds/muls we first convert img1 to float by multiplying it with ratio/(1-ratio).
+        # This permits us to add img2 in-place to it, without further copies.
+        # To ensure we have the correct result at the end, we multiply in-place with (1-ratio).
+        result = img1 * (ratio / (1.0 - ratio))
+        result.add_(img2).mul_(1.0 - ratio)
+
+    return result.clamp_(0, bound).to(img1.dtype)
 
 
 def _rgb2hsv(img: Tensor) -> Tensor:
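
For reference, the rewritten else branch relies on the identity (img1 * ratio / (1 - ratio) + img2) * (1 - ratio) == ratio * img1 + (1 - ratio) * img2. A minimal sketch (not part of the PR) that checks the rewritten path against the original v1 formula on a random uint8 image:

import torch

img1 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
img2 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
ratio = 0.3

# Original (v1) formulation: two out-of-place multiplies plus an out-of-place add.
expected = (ratio * img1 + (1.0 - ratio) * img2).clamp(0, 255.0)

# Rewritten path: the first multiply promotes the uint8 input to float,
# so the remaining ops can run in-place without copying or modifying img2.
result = img1 * (ratio / (1.0 - ratio))
result.add_(img2).mul_(1.0 - ratio).clamp_(0, 255.0)

torch.testing.assert_close(result, expected)  # equal up to float rounding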