improve perf on convert_image_dtype and add tests #6795

Merged
6 commits merged on Oct 20, 2022

Conversation

@pmeier pmeier (Collaborator) commented Oct 19, 2022

The improvements come from using inplace operations where possible.
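A minimal sketch of the idea (not the actual torchvision code): chaining an in-place op onto the freshly allocated output of `Tensor.to` avoids the extra temporary that an out-of-place scale would create.

```python
import torch

def uint8_to_float32(image: torch.Tensor) -> torch.Tensor:
    # `Tensor.to` has to allocate a new tensor anyway, but the scaling can
    # then happen in place on that tensor. The out-of-place spelling
    # `image.to(torch.float32) / 255` would allocate a second temporary.
    return image.to(torch.float32).div_(255)

image = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
out = uint8_to_float32(image)
```

The same trick applies to the other branches wherever an op follows an allocation that the conversion had to make anyway.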

benchmark script

import itertools
import pathlib
import pickle

import torch
from torch.utils import benchmark
import functools

from torchvision.prototype.transforms import functional as F

description = "PR"  # "main", "PR"


def make_inputs(*, input_dtype, output_dtype, device, shape=(3, 512, 512)):
    if input_dtype.is_floating_point:
        image = torch.rand(shape, dtype=input_dtype, device=device)
    else:
        image = torch.randint(0, torch.iinfo(input_dtype).max + 1, shape, dtype=input_dtype, device=device)
    return image, output_dtype


sub_labels_and_input_fns = [
    ("float to float", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.float64)),
    ("float to   int", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.uint8)),
    ("  int to float", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.float32)),
    ("  int to   int (down)", functools.partial(make_inputs, input_dtype=torch.int32, output_dtype=torch.uint8)),
    ("  int to   int (up)", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.int32)),
]


timers = [
    benchmark.Timer(
        stmt="convert_image_dtype(*inputs)",
        globals=dict(
            convert_image_dtype=F.convert_image_dtype,
            inputs=inputs_fn(device=device),
        ),
        label="convert_image_dtype perf improvements",
        sub_label=f"{device:4} / {sub_label}",
        description=description,
        num_threads=num_threads,
    )
    for (sub_label, inputs_fn), device in itertools.product(sub_labels_and_input_fns, ["cpu", "cuda"])
    for num_threads in ([1, 2, 4] if device == "cpu" else [1])
]

measurements = [timer.blocked_autorange(min_run_time=5) for timer in timers]


with open(f"{description}.measurements", "wb") as fh:
    pickle.dump(measurements, fh)

measurements = []
for file in pathlib.Path(".").glob("*.measurements"):
    with open(file, "rb") as fh:
        measurements.extend(pickle.load(fh))

comparison = benchmark.Compare(measurements)
comparison.trim_significant_figures()
comparison.print()

[----- convert_image_dtype perf improvements ------]
                                    |  main  |   PR 
1 threads: -----------------------------------------
      cpu  / float to float         |    88  |    82
      cuda / float to float         |    42  |    42
      cpu  / float to   int         |   380  |   360
      cuda / float to   int         |    46  |    50
      cpu  /   int to float         |   136  |   130
      cuda /   int to float         |    47  |    47
      cpu  /   int to   int (down)  |  1050  |  1070
      cuda /   int to   int (down)  |    44  |    44
      cpu  /   int to   int (up)    |   120  |    88
      cuda /   int to   int (up)    |    46  |    46
2 threads: -----------------------------------------
      cpu  / float to float         |    53  |    46
      cpu  / float to   int         |   210  |   199
      cpu  /   int to float         |    82  |    76
      cpu  /   int to   int (down)  |   560  |   546
      cpu  /   int to   int (up)    |    74  |    55
4 threads: -----------------------------------------
      cpu  / float to float         |    31  |    27
      cpu  / float to   int         |   115  |   108
      cpu  /   int to float         |    51  |    45
      cpu  /   int to   int (down)  |   293  |   286
      cpu  /   int to   int (up)    |    47  |    35

Times are in microseconds (us).

The branches that are improved are

  • float to float
  • float to int
  • int to int (up)

Of these, float to int is the most interesting for us, since we regularly convert between torch.uint8 and torch.float32 before we normalize. With this patch, we get the following diff when profiling with @vfdev-5's benchmark scripts:

-      2000    0.047    0.000    0.095    0.000 /home/philip/git/pytorch/torchvision/torchvision/transforms/functional_tensor.py:68(convert_image_dtype)
+      2000    0.005    0.000    0.072    0.000 /home/philip/git/pytorch/torchvision/torchvision/prototype/transforms/functional/_type_conversion.py:46(convert_image_dtype)

cc @vfdev-5 @datumbox @bjuncek

Comment on lines +317 to +323
@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
Collaborator Author

This is a rather convoluted way to get the sample inputs for a single kernel. I'll refactor later, since this is low priority right now.

@pmeier pmeier left a comment (Collaborator Author)

Another round of benchmarks after the new commits. The benchmark now drops the CUDA runs and tests only a single thread; on the flip side, each measurement now runs longer to reduce noise.

benchmark script

import pathlib
import pickle

import torch
from torch.utils import benchmark
import functools

from torchvision.prototype.transforms import functional as F

description = "PR"  # "main", "PR"


def make_inputs(*, input_dtype, output_dtype, shape=(3, 512, 512)):
    if input_dtype.is_floating_point:
        image = torch.rand(shape, dtype=input_dtype)
    else:
        image = torch.randint(0, torch.iinfo(input_dtype).max + 1, shape, dtype=input_dtype)
    return image, output_dtype


sub_labels_and_input_fns = [
    ("float to float", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.float64)),
    ("float to   int", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.uint8)),
    ("  int to float", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.float32)),
    ("  int to   int (down)", functools.partial(make_inputs, input_dtype=torch.int32, output_dtype=torch.uint8)),
    ("  int to   int (up)", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.int32)),
]


timers = [
    benchmark.Timer(
        stmt="convert_image_dtype(*inputs)",
        globals=dict(
            convert_image_dtype=F.convert_image_dtype,
            inputs=inputs_fn(),
        ),
        label="convert_image_dtype perf improvements",
        sub_label=sub_label,
        description=description,
        num_threads=1,
    )
    for sub_label, inputs_fn in sub_labels_and_input_fns
]

measurements = [timer.blocked_autorange(min_run_time=15) for timer in timers]


with open(f"{description}.measurements", "wb") as fh:
    pickle.dump(measurements, fh)

measurements = []
for file in pathlib.Path(".").glob("*.measurements"):
    with open(file, "rb") as fh:
        measurements.extend(pickle.load(fh))

comparison = benchmark.Compare(measurements)
comparison.trim_significant_figures()
comparison.print()

[- convert_image_dtype perf improvements --]
                             |  main  |   PR
1 threads: ---------------------------------
      float to float         |    90  |   83
      float to   int         |   380  |  380
        int to float         |   138  |  134
        int to   int (down)  |  1100  |  402
        int to   int (up)    |   127  |   92

Times are in microseconds (us).
  • "float to float", "int to float", "int to int (up)" did not change from the last benchmarks and are still faster
  • "int to int (down)" now uses bit shifts and is 3x faster
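The int-to-int downscale amounts to a right shift by the difference in value bits. A plain-Python sketch of the idea (the helper names here are illustrative, not the torchvision internals):

```python
def num_value_bits(dtype_bits: int, signed: bool) -> int:
    # Signed integer dtypes lose one bit to the sign.
    return dtype_bits - 1 if signed else dtype_bits

def downscale(value: int, input_bits: int, output_bits: int) -> int:
    # Dividing by 2**(input_bits - output_bits) equals shifting right by
    # that amount for non-negative values, but the shift is much cheaper
    # than an integer division.
    return value >> (input_bits - output_bits)

# int32 (31 value bits) -> uint8 (8 value bits): shift right by 23
print(downscale(2**31 - 1, num_value_bits(32, True), num_value_bits(8, False)))  # 255
```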

if output_dtype.is_floating_point:
    return value
else:
    return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
Collaborator Author

This gives us arbitrary floating point precision for the intermediate calculations, which is what we want for the reference function. You can see from the xfails I needed to add below that we need this in some cases.
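A small illustration of why exact decimal arithmetic matters here (values chosen for illustration): float64 cannot represent the int64 maximum 2**63 - 1 exactly, so plain float math rounds it up and overflows the output range, while `decimal.Decimal` keeps every digit.

```python
import decimal

max_int64 = 2**63 - 1  # what torch.iinfo(torch.int64).max would give

# Plain float arithmetic: 2**63 - 1 rounds up to 2**63.0, so the result
# lands one past the representable int64 range.
naive = int(1.0 * max_int64)

# Exact decimal arithmetic keeps the value bit for bit.
exact = int(decimal.Decimal(1.0) * max_int64)

print(naive)  # 9223372036854775808 (out of int64 range)
print(exact)  # 9223372036854775807
```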

Comment on lines +2051 to +2062
condition=lambda args_kwargs: (
    args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
    and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
    args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
    and args_kwargs.kwargs["dtype"] == torch.int64
)
or (
    args_kwargs.args[0].dtype in {torch.int32, torch.int64}
    and args_kwargs.kwargs["dtype"] == torch.float16
),
Collaborator Author

I'm going to open an issue soon detailing what is happening in these cases and how we could mitigate it.

Comment on lines +110 to +114
# The bitshift kernel is not vectorized
# https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
# This results in the multiplication actually being faster.
# TODO: If the bitshift kernel is optimized in core, replace the computation below with
# `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
Collaborator Author

Per comment. The same applies to the bitwise_right_shift kernel in the branch above, but that is still much faster than the division we had before.
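For non-negative integers and a power-of-two factor, floor division and a right shift agree bit for bit, which is why the shift-based kernel can produce identical results to the old division. A plain-Python sketch (the 23-bit shift corresponds to going from 31 value bits down to 8):

```python
factor = 2**23  # int32 has 31 value bits, uint8 has 8: 31 - 8 = 23

for value in (0, 1, factor - 1, factor, 2**31 - 1):
    # For non-negative integers, `value // 2**n` and `value >> n` are
    # mathematically identical; the shift just skips the division kernel.
    assert value // factor == value >> 23

print("floor division and right shift agree")
```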

@datumbox datumbox left a comment (Contributor)

LGTM, just one question:

Edit: Lol, Github lost my comment. I was asking if we are confident that the use of bitwise_right_shift produces identical results to the previous implementation.


pmeier commented Oct 20, 2022

I was asking if we are confident that the use of bitwise_right_shift produces identical results to the previous implementation.

Yes, I am. I've added reference tests just to make sure I'm not introducing any changes here. If you look at them, you'll see that they actually use the old idiom of multiplying or dividing by the factors.

Comment on lines +1991 to +1996
if input_max_value > output_max_value:
    factor = (input_max_value + 1) // (output_max_value + 1)
    return value // factor
else:
    factor = (output_max_value + 1) // (input_max_value + 1)
    return value * factor
Collaborator Author

Pointer for my comment above.

@pmeier pmeier merged commit 211563f into pytorch:main Oct 20, 2022
@pmeier pmeier deleted the convert-image-dtype branch October 20, 2022 12:14
facebook-github-bot pushed a commit that referenced this pull request Oct 21, 2022
Summary:
* improve perf on convert_image_dtype and add tests

* add reference tests

* use bitshifts for int to int

* revert bitshifts for int to int upscale

* fix warning ignore

Reviewed By: YosuaMichael

Differential Revision: D40588162

fbshipit-source-id: 4f1c564f94f75ff37979c123a416b043b4c9ec14