diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 20b76fbf079..9ef16c13cbe 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -262,9 +262,27 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: + if ratio == 1.0: + return img1 ratio = float(ratio) bound = 1.0 if img1.is_floating_point() else 255.0 - return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) + + if img2.is_floating_point(): + # Since img2 is float, we can do in-place ops on it. It's a throw-away tensor. + # Our strategy is to convert img1 to float and copy it to avoid in-place modifications, + # update img2 in-place and add it on the result with an in-place op. + result = img1 * ratio + img2.mul_(1.0 - ratio) + result.add_(img2) + else: + # Since img2 is not float, we can't do in-place ops on it. + # To minimize copies/adds/muls we first convert img1 to float by multiplying it with ratio/(1-ratio). + # This permits us to add img2 in-place to it, without further copies. + # To ensure we have the correct result at the end, we multiply in-place with (1-ratio). + result = img1 * (ratio / (1.0 - ratio)) + result.add_(img2).mul_(1.0 - ratio) + + return result.clamp_(0, bound).to(img1.dtype) def _rgb2hsv(img: Tensor) -> Tensor: