improve perf on convert_image_dtype and add tests #6795
Conversation
@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
This is rather convoluted to get the sample inputs for a single kernel. I'll refactor later since this is low priority right now.
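When this gets refactored, the lookup could live in a small helper. A hypothetical sketch — `kernel` and `sample_inputs_fn` are attribute names taken from the snippet above, and the stand-in infos below are toys, not the real `KERNEL_INFOS`:

```python
from types import SimpleNamespace


def sample_inputs_for_kernel(kernel, kernel_infos):
    """Find the info entry for `kernel` and materialize its sample inputs.

    Hypothetical helper; the attribute names mirror the snippet above
    and are not a public torchvision API.
    """
    info = next(info for info in kernel_infos if info.kernel is kernel)
    return info, list(info.sample_inputs_fn())


# Toy stand-ins just to show the shape of the call:
def my_kernel(x):
    return x


infos = [
    SimpleNamespace(kernel=my_kernel, sample_inputs_fn=lambda: iter([1, 2, 3])),
    SimpleNamespace(kernel=print, sample_inputs_fn=lambda: iter([])),
]

info, samples = sample_inputs_for_kernel(my_kernel, infos)
```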
Another round of benchmarks after the new commits. Compared to the last round, the benchmark drops the CUDA test and only measures a single thread. On the flip side, measurements now run longer to reduce noise.
benchmark script
import pathlib
import pickle
import functools

import torch
from torch.utils import benchmark

from torchvision.prototype.transforms import functional as F

description = "PR"  # one of "main", "PR"


def make_inputs(*, input_dtype, output_dtype, shape=(3, 512, 512)):
    if input_dtype.is_floating_point:
        image = torch.rand(shape, dtype=input_dtype)
    else:
        image = torch.randint(0, torch.iinfo(input_dtype).max + 1, shape, dtype=input_dtype)
    return image, output_dtype


sub_labels_and_input_fns = [
    ("float to float", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.float64)),
    ("float to int", functools.partial(make_inputs, input_dtype=torch.float32, output_dtype=torch.uint8)),
    (" int to float", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.float32)),
    (" int to int (down)", functools.partial(make_inputs, input_dtype=torch.int32, output_dtype=torch.uint8)),
    (" int to int (up)", functools.partial(make_inputs, input_dtype=torch.uint8, output_dtype=torch.int32)),
]

timers = [
    benchmark.Timer(
        stmt="convert_image_dtype(*inputs)",
        globals=dict(
            convert_image_dtype=F.convert_image_dtype,
            inputs=inputs_fn(),
        ),
        label="convert_image_dtype perf improvements",
        sub_label=sub_label,
        description=description,
        num_threads=1,
    )
    for sub_label, inputs_fn in sub_labels_and_input_fns
]

measurements = [timer.blocked_autorange(min_run_time=15) for timer in timers]
with open(f"{description}.measurements", "wb") as fh:
    pickle.dump(measurements, fh)

measurements = []
for file in pathlib.Path(".").glob("*.measurements"):
    with open(file, "rb") as fh:
        measurements.extend(pickle.load(fh))

comparison = benchmark.Compare(measurements)
comparison.trim_significant_figures()
comparison.print()
[---------- convert_image_dtype perf improvements ----------]
                           |  main  |   PR
1 threads: --------------------------------------------------
      float to float       |    90  |    83
      float to int         |   380  |   380
      int to float         |   138  |   134
      int to int (down)    |  1100  |   402
      int to int (up)      |   127  |    92

Times are in microseconds (us).
- "float to float", "int to float", "int to int (up)" did not change from the last benchmarks and are still faster
- "int to int (down)" now uses bit shifts and is 3x faster
if output_dtype.is_floating_point:
    return value
else:
    return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
This gives us arbitrary floating-point precision for the intermediate calculations, which is what we want for the reference function. You can see from the xfails I needed to add below that we need this in some cases.
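To illustrate why the exact decimal arithmetic matters, here is a self-contained sketch (plain Python, with `2**63 - 1` standing in for `torch.iinfo(torch.int64).max`): the float path rounds the int64 maximum to the nearest float64 before multiplying, while the `decimal.Decimal` path carries the product exactly.

```python
import decimal

INT64_MAX = 2**63 - 1  # same value as torch.iinfo(torch.int64).max

# Naive float path: INT64_MAX is first rounded to the nearest float64
# (which is 2**63), so the product is off by one after truncation.
float_result = int(0.5 * INT64_MAX)

# Decimal path: 0.5 is exactly representable, and the product
# 4611686018427387903.5 is carried exactly before truncation.
exact_result = int(decimal.Decimal(0.5) * INT64_MAX)
```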
condition=lambda args_kwargs: (
    args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
    and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
    args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
    and args_kwargs.kwargs["dtype"] == torch.int64
)
or (
    args_kwargs.args[0].dtype in {torch.int32, torch.int64}
    and args_kwargs.kwargs["dtype"] == torch.float16
),
I'm going to open an issue soon detailing what is happening in these cases and how we could mitigate it.
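One of the listed cases can be illustrated without torch: float16's largest finite value is 65504, so the scale factor for a conversion to `torch.int64` (`2**63 - 1`) is not representable in half precision at all. A sketch using `struct`'s half-precision format as a stand-in for `torch.float16`:

```python
import struct

FLOAT16_MAX = 65504.0  # largest finite IEEE 754 half-precision value
INT64_MAX = 2**63 - 1  # scale factor when converting float -> int64


def fits_in_float16(value):
    """Return True if `value` can be packed as an IEEE 754 half float."""
    try:
        struct.pack("e", float(value))  # "e" is the half-precision format
        return True
    except OverflowError:
        return False


max_ok = fits_in_float16(FLOAT16_MAX)  # the largest finite half works
scale_ok = fits_in_float16(INT64_MAX)  # the int64 scale factor overflows
```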
# The bitshift kernel is not vectorized
# https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
# This results in the multiplication actually being faster.
# TODO: If the bitshift kernel is optimized in core, replace the computation below with
# `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
Per the comment. The same applies to the bitwise_right_shift kernel in the branch above, but that is still much faster than the division we had before.
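The equivalence the TODO relies on is easy to check in plain Python: for non-negative integers, multiplying by `2**n` and shifting left by `n` agree bit for bit. For the uint8 to int32 upscale (8 vs. 31 value bits):

```python
# uint8 has 8 value bits; int32 has 31 (one bit is the sign).
num_value_bits_input = 8
num_value_bits_output = 31
shift = num_value_bits_output - num_value_bits_input  # 23

# The multiplication variant currently used by the kernel...
factor = 2**shift

# ...agrees with the bitshift variant from the TODO for every uint8 value.
mismatches = [x for x in range(256) if x * factor != x << shift]
```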
LGTM, just one question:

Edit: Lol, GitHub lost my comment. I was asking if we are confident that the use of bitwise_right_shift produces identical results to the previous implementation.

Yes, I am. I've added reference tests just to make sure I'm not introducing anything here. If you look at them, I'm actually using the old idiom of multiplying or dividing by the factors.
if input_max_value > output_max_value:
    factor = (input_max_value + 1) // (output_max_value + 1)
    return value // factor
else:
    factor = (output_max_value + 1) // (input_max_value + 1)
    return value * factor
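For the downscale branch, the reviewer's question boils down to whether floor division by a power of two matches a right shift, which it does for non-negative values. A quick check for the int32 to uint8 case (`factor = 2**31 // 2**8 = 2**23`):

```python
# int32 -> uint8: factor = (2**31) // (2**8) = 2**23
shift = 23
factor = 2**shift

# Floor division and right shift agree for all non-negative inputs.
samples = [0, 1, factor - 1, factor, 2**31 - 1]
mismatches = [v for v in samples if v // factor != v >> shift]

# Sanity check: the int32 maximum lands on the uint8 maximum.
int32_max_mapped = (2**31 - 1) >> shift
```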
Pointer for my comment above.
Summary:
* improve perf on convert_image_dtype and add tests
* add reference tests
* use bitshifts for int to int
* revert bitshifts for int to int upscale
* fix warning ignore

Reviewed By: YosuaMichael

Differential Revision: D40588162

fbshipit-source-id: 4f1c564f94f75ff37979c123a416b043b4c9ec14
The improvements come from using inplace operations where possible.
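As a sketch of the kind of saving involved (using NumPy's `out=` parameter to stand in for torch's trailing-underscore inplace ops, assuming NumPy is available): writing into the existing buffer avoids allocating a temporary for every intermediate result.

```python
import numpy as np

image = np.arange(6, dtype=np.float32).reshape(2, 3)
buffer_before = image.ctypes.data  # address of the underlying buffer

# Out-of-place: allocates a new array for the result.
scaled_copy = image * 2.0

# In-place: np.multiply with out= reuses image's buffer,
# analogous to torch's image.mul_(2.0).
np.multiply(image, 2.0, out=image)
```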
The branches that are improved are "float to float", "int to float", "int to int (down)", and "int to int (up)".
Of these, int to float is the most interesting for us, since we regularly convert torch.uint8 to torch.float32 before we normalize. With this patch, we get the following diff when profiling with @vfdev-5's benchmark script.

cc @vfdev-5 @datumbox @bjuncek