Skip to content

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 13, 2022

This PR exists only to showcase the differences against the previous code. It shouldn't be merged, as we should apply this optimization directly to V2.

Results:

adjust_brightness - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.0017781441507395356, v2: 0.0015310867200605573 - Diff: -13.89%
adjust_brightness - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0008344144199509174, v2: 0.0005974295809864998 - Diff: -28.40%
adjust_brightness - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 5.151713805971667e-05, v2: 4.7857200680300596e-05 - Diff: -7.10%
adjust_brightness - Winner: v2(device=cuda, dtype=torch.float32) - v1: 4.2962464748416095e-05, v2: 3.8983459910377864e-05 - Diff: -9.26%
adjust_contrast - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.0019489222508855164, v2: 0.001805254714563489 - Diff: -7.37%
adjust_contrast - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0009198985202237964, v2: 0.0007649726495146752 - Diff: -16.84%
adjust_contrast - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.00012647993885912, v2: 0.0001235444841440767 - Diff: -2.32%
adjust_contrast - Winner: v2(device=cuda, dtype=torch.float32) - v1: 9.984084800817073e-05, v2: 9.742333088070155e-05 - Diff: -2.42%
adjust_saturation - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.0019948067096993327, v2: 0.0018562771938741207 - Diff: -6.94%
adjust_saturation - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0009555038996040821, v2: 0.0007746961107477546 - Diff: -18.92%
adjust_saturation - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.0001049026179825887, v2: 0.00010217939910944551 - Diff: -2.60%
adjust_saturation - Winner: v2(device=cuda, dtype=torch.float32) - v1: 8.710524649359286e-05, v2: 8.433246007189155e-05 - Diff: -3.18%
adjust_sharpness - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.007391045135445893, v2: 0.006958624790422618 - Diff: -5.85%
adjust_sharpness - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.004595479456475005, v2: 0.003972332294797525 - Diff: -13.56%
adjust_sharpness - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.0001897168857976794, v2: 0.000187589162029326 - Diff: -1.12%
adjust_sharpness - Winner: v2(device=cuda, dtype=torch.float32) - v1: 0.00015073937014676631, v2: 0.00014734733151271939 - Diff: -2.25%
Benchmark Script
import torch
import torch.utils.benchmark as benchmark

from torch import Tensor
from torchvision.transforms.functional_tensor import _assert_image_tensor, _assert_channels, get_dimensions, rgb_to_grayscale, _blurred_degenerate_image
from torchvision.transforms import functional_tensor as V1



def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    if ratio == 1.0:
        return img1
    ratio = float(ratio)
    bound = 1.0 if img1.is_floating_point() else 255.0

    if img2.is_floating_point():
        # Since img2 is float, we can do in-place ops on it. It's a throw-away tensor.
        # Our strategy is to convert img1 to float and copy it to avoid in-place modifications,
        # update img2 in-place and add it on the result with an in-place op.
        result = img1 * ratio
        img2.mul_(1.0 - ratio)
        result.add_(img2)
    else:
        # Since img2 is not float, we can't do in-place ops on it.
        # To minimize copies/adds/muls we first convert img1 to float by multiplying it with ratio/(1-ratio).
        # This permits us to add img2 in-place to it, without further copies.
        # To ensure we have the correct result at the end, we multiply in-place with (1-ratio).
        result = img1 * (ratio / (1.0 - ratio))
        result.add_(img2).mul_(1.0 - ratio)

    return result.clamp_(0, bound).to(img1.dtype)


############ COPY PASTE FROM functional_tensor.py
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale the brightness of ``img`` by blending it with an all-zero image.

    Args:
        img: Image tensor; checked by ``_assert_image_tensor`` and
            ``_assert_channels(img, [1, 3])``.
        brightness_factor: Non-negative blend weight given to ``img``.

    Returns:
        The result of ``_blend(img, zeros, brightness_factor)``.

    Raises:
        ValueError: If ``brightness_factor`` is negative.
    """
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # The degenerate image for brightness is all zeros, same shape/dtype as img.
    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast by blending ``img`` with its mean intensity.

    Args:
        img: Image tensor; checked by ``_assert_image_tensor`` and
            ``_assert_channels(img, [3, 1])``.
        contrast_factor: Non-negative blend weight given to ``img``.

    Returns:
        The result of ``_blend(img, mean, contrast_factor)`` where ``mean`` is
        taken over the last three dimensions (kept for broadcasting).

    Raises:
        ValueError: If ``contrast_factor`` is negative.
    """
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [3, 1])

    num_channels = get_dimensions(img)[0]
    # The mean is computed in floating point: keep the image's own dtype when it
    # is already floating point, otherwise use float32.
    mean_dtype = img.dtype if img.is_floating_point() else torch.float32
    # Three-channel images are reduced to grayscale first; others are used as is.
    source = rgb_to_grayscale(img) if num_channels == 3 else img
    mean = source.to(mean_dtype).mean(dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust saturation by blending ``img`` with its grayscale version.

    Args:
        img: Image tensor; checked by ``_assert_image_tensor`` and
            ``_assert_channels(img, [1, 3])``.
        saturation_factor: Non-negative blend weight given to ``img``.

    Returns:
        ``img`` itself for single-channel input, otherwise the result of
        ``_blend(img, rgb_to_grayscale(img), saturation_factor)``.

    Raises:
        ValueError: If ``saturation_factor`` is negative.
    """
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    num_channels = get_dimensions(img)[0]
    if num_channels == 1:
        # Match PIL behaviour: single-channel images are returned unchanged.
        return img

    grayscale = rgb_to_grayscale(img)
    return _blend(img, grayscale, saturation_factor)


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust sharpness by blending ``img`` with a blurred copy of itself.

    Args:
        img: Image tensor; checked by ``_assert_image_tensor`` and
            ``_assert_channels(img, [1, 3])``.
        sharpness_factor: Non-negative blend weight given to ``img``.

    Returns:
        ``img`` itself when either of its last two dimensions is at most 2,
        otherwise ``_blend(img, _blurred_degenerate_image(img), sharpness_factor)``.

    Raises:
        ValueError: If ``sharpness_factor`` is negative.
    """
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # Images this small are returned untouched instead of being blurred.
    too_small = img.shape[-1] <= 2 or img.shape[-2] <= 2
    if too_small:
        return img

    blurred = _blurred_degenerate_image(img)
    return _blend(img, blurred, sharpness_factor)


############### BENCHMARKS

def bench_torch(fn, data, factor, runtime=10):
    """Return the median time of ``fn(data, factor)`` measured by ``benchmark.Timer``.

    Args:
        fn: Callable invoked as ``fn(data, factor)``.
        data: First argument passed to ``fn``.
        factor: Second argument passed to ``fn``. It is handed to the timer as an
            object, so it does not need a string form that is valid Python.
        runtime: Value forwarded as ``min_run_time`` to ``blocked_autorange``.

    Returns:
        The ``median`` of the measurement returned by ``blocked_autorange``.
    """
    # Warm-up calls before any timing takes place.
    for _ in range(10):
        fn(data, factor)

    timer = benchmark.Timer(
        # `factor` is looked up from `globals` instead of being formatted into
        # the statement's source code, which broke for values such as tensors.
        stmt="fn(data, factor)",
        globals={
            "data": data,
            "fn": fn,
            "factor": factor,
        },
        num_threads=torch.get_num_threads(),
    )
    return timer.blocked_autorange(min_run_time=runtime).median


# Benchmark inputs of shape (10, 3, 128, 128).
# torch.randint produces int64 by default, so uint8 has to be requested
# explicitly; otherwise the runs reported as uint8 would measure int64 tensors.
data = torch.randint(0, 256, (10, 3, 128, 128), dtype=torch.uint8)
# True division of an integer tensor yields a floating-point copy in [0, 1].
data_f = data / 255.0
factor = 1.5
bench_fn = bench_torch


devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")
for fn_name in ["adjust_brightness", "adjust_contrast", "adjust_saturation", "adjust_sharpness"]:
    fn_v1 = getattr(V1, fn_name)
    # The optimized implementations are defined at module level in this script.
    fn_v2 = globals()[fn_name]
    for device in devices:
        for img in (data.to(device=device), data_f.to(device=device)):
            # Report the dtype that is actually benchmarked, not a hand-written label.
            dtype = img.dtype
            v1_time = bench_fn(fn_v1, img, factor)
            v2_time = bench_fn(fn_v2, img, factor)

            winner = "v1" if v1_time < v2_time else "v2"
            print(f"{fn_name} - Winner: {winner}(device={device}, dtype={dtype}) - v1: {v1_time}, v2: {v2_time} - Diff: {100*(v2_time-v1_time)/v1_time:.2f}%")
            # Sanity check: both versions must agree on the same input.
            # assert_close signals a mismatch with AssertionError; anything else
            # is a genuine error and is allowed to propagate.
            try:
                torch.testing.assert_close(fn_v1(img, factor), fn_v2(img, factor))
            except AssertionError:
                print("WARNING: fv1(x) != fv2(x)")

cc @vfdev-5 @bjuncek @pmeier

@datumbox datumbox marked this pull request as draft October 13, 2022 18:21
# Our strategy is to convert img1 to float and copy it to avoid in-place modifications,
# update img2 in-place and add it on the result with an in-place op.
result = img1 * ratio
img2.mul_(1.0 - ratio)
Copy link
Contributor

@vfdev-5 vfdev-5 Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether it is safe to update one of the inputs in place with img2.mul_(1.0 - ratio)?
Maybe we could use the same trick as in the else branch:

result = (img1 * ratio).div_(1.0 - ratio)
result.add_(img2).mul_(1.0 - ratio)

Another point: if ratio is 1.0, we may get NaN for no good reason...

Copy link
Contributor Author

@datumbox datumbox Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the NaN issue in the latest commit.

Your proposal should be slower because it performs an extra division over the entire tensor. Happy to investigate more alternatives. BTW, here is an alternative implementation that is faster than V1 but slower than this PR:

def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    if ratio == 1.0:
        return img1
    ratio = float(ratio)
    bound = 1.0 if img1.is_floating_point() else 255.0

    result = img1 * (ratio / (1.0 - ratio))
    result.add_(img2).mul_(1.0 - ratio).clamp_(0, bound)

    return result.to(img1.dtype)

Benchmarks:

adjust_brightness - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.0017525686789304018, v2: 0.0015204967744648456 - Diff: -13.24%
adjust_brightness - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0008267542696557939, v2: 0.0005970064329449087 - Diff: -27.79%
adjust_brightness - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 5.1363299996592105e-05, v2: 4.837187579832971e-05 - Diff: -5.82%
adjust_brightness - Winner: v2(device=cuda, dtype=torch.float32) - v1: 4.312360244803131e-05, v2: 4.006389770656824e-05 - Diff: -7.10%
adjust_contrast - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.001936913060490042, v2: 0.0018658265587873756 - Diff: -3.67%
adjust_contrast - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0009221774595789611, v2: 0.0008399060717783868 - Diff: -8.92%
adjust_contrast - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.0001281447554938495, v2: 0.00012599090137518942 - Diff: -1.68%
adjust_contrast - Winner: v2(device=cuda, dtype=torch.float32) - v1: 0.00010212179413065314, v2: 9.956091595813633e-05 - Diff: -2.51%
adjust_saturation - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.001968699339777231, v2: 0.001849049050360918 - Diff: -6.08%
adjust_saturation - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.0009527031611651182, v2: 0.0008303909911774099 - Diff: -12.84%
adjust_saturation - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.00010614430147688837, v2: 0.00010406337701715528 - Diff: -1.96%
adjust_saturation - Winner: v2(device=cuda, dtype=torch.float32) - v1: 8.796596352476627e-05, v2: 8.61313920468092e-05 - Diff: -2.09%
adjust_sharpness - Winner: v2(device=cpu, dtype=torch.uint8) - v1: 0.007410917698871344, v2: 0.00669603182002902 - Diff: -9.65%
adjust_sharpness - Winner: v2(device=cpu, dtype=torch.float32) - v1: 0.004427262460812926, v2: 0.003999999698717147 - Diff: -9.65%
adjust_sharpness - Winner: v2(device=cuda, dtype=torch.uint8) - v1: 0.00018968257401138545, v2: 0.00018811142595950514 - Diff: -0.83%
adjust_sharpness - Winner: v2(device=cuda, dtype=torch.float32) - v1: 0.00015139597153756768, v2: 0.00014969582995399833 - Diff: -1.12%

@datumbox
Copy link
Contributor Author

Superseded by the work at #6784

@datumbox datumbox closed this Oct 20, 2022
@datumbox datumbox deleted the prototype/opt_blend branch October 20, 2022 10:50
@datumbox datumbox added module: transforms Perf For performance improvements prototype labels Oct 27, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants