Skip to content

Commit

Permalink
Adding GPU acceleration to encode_jpeg (#8391)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
3 people committed Jun 13, 2024
1 parent f96c42f commit 143d078
Show file tree
Hide file tree
Showing 10 changed files with 622 additions and 20 deletions.
67 changes: 67 additions & 0 deletions benchmarks/encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import platform
import statistics

import torch
import torch.utils.benchmark as benchmark
import torchvision


def print_machine_specs():
print("Processor:", platform.processor())
print("Platform:", platform.platform())
print("Logical CPUs:", os.cpu_count())
print(f"\nCUDA device: {torch.cuda.get_device_name()}")
print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


def get_data():
transform = torchvision.transforms.Compose(
[
torchvision.transforms.PILToTensor(),
]
)
path = os.path.join(os.getcwd(), "data")
testset = torchvision.datasets.Places365(
root="./data", download=not os.path.exists(path), transform=transform, split="val"
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch]
)
return next(iter(testloader))


def run_benchmark(batch):
results = []
for device in ["cpu", "cuda"]:
batch_device = [t.to(device=device) for t in batch]
for size in [1, 100, 1000]:
for num_threads in [1, 12, 24]:
for stmt, strat in zip(
[
"[torchvision.io.encode_jpeg(img) for img in batch_input]",
"torchvision.io.encode_jpeg(batch_input)",
],
["unfused", "fused"],
):
batch_input = batch_device[:size]
t = benchmark.Timer(
stmt=stmt,
setup="import torchvision",
globals={"batch_input": batch_input},
label="Image Encoding",
sub_label=f"{device.upper()} ({strat}): {stmt}",
description=f"{size} images",
num_threads=num_threads,
)
results.append(t.blocked_autorange())
compare = benchmark.Compare(results)
compare.print()


if __name__ == "__main__":
print_machine_specs()
batch = get_data()
mean_h, mean_w = statistics.mean(t.shape[-2] for t in batch), statistics.mean(t.shape[-1] for t in batch)
print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}")
run_benchmark(batch)
197 changes: 196 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
import glob
import io
import os
Expand All @@ -10,7 +11,7 @@
import requests
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, IN_OSS_CI, needs_cuda
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
Expand Down Expand Up @@ -508,6 +509,200 @@ def test_encode_jpeg(img_path, scripted):
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@needs_cuda
def test_encode_jpeg_cuda_device_param():
path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)

data = read_image(path)

current_device = torch.cuda.current_device()
current_stream = torch.cuda.current_stream()
num_devices = torch.cuda.device_count()
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
results = []
for device in devices:
print(f"python: device: {device}")
results.append(encode_jpeg(data.to(device=device)))
assert len(results) == len(devices)
for result in results:
assert torch.all(result.cpu() == results[0].cpu())

assert current_device == torch.cuda.current_device()
assert current_stream == torch.cuda.current_stream()


@needs_cuda
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
decoded_image_tv = read_image(img_path)
encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

if "cmyk" in img_path:
pytest.xfail("Encoding a CMYK jpeg isn't supported")
if decoded_image_tv.shape[0] == 1:
pytest.xfail("Decoding a grayscale jpeg isn't supported")
# For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
if contiguous:
decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.contiguous_format)[0]
else:
decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.channels_last)[0]
encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

# the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
# instead, we re-decode the encoded image and compare to the original
abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
assert abs_mean_diff < 3


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
if device == "cpu" and IS_MACOS:
pytest.skip("https://github.com/pytorch/vision/issues/8031")
decoded_images_tv = []
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
if "cmyk" in jpeg_path:
continue
decoded_image = read_image(jpeg_path)
if decoded_image.shape[0] == 1:
continue
if contiguous:
decoded_image = decoded_image[None].contiguous(memory_format=torch.contiguous_format)[0]
else:
decoded_image = decoded_image[None].contiguous(memory_format=torch.channels_last)[0]
decoded_images_tv.append(decoded_image)

encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
encoded_jpegs_tv_device = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device]

for original, encoded_decoded in zip(decoded_images_tv, encoded_jpegs_tv_device):
c, h, w = original.shape
abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item()
assert abs_mean_diff < 3

# test multithreaded decoding
# in the current version we prevent this by using a lock but we still want to test it
num_workers = 10
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
encoded_images_threaded = [future.result() for future in futures]
assert len(encoded_images_threaded) == num_workers
for encoded_images in encoded_images_threaded:
assert len(decoded_images_tv_device) == len(encoded_images)
for i, (encoded_image_cuda, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
# make sure all the threads produce identical outputs
assert torch.all(encoded_image_cuda == encoded_images_threaded[0][i])

# make sure the outputs are identical or close enough to baseline
decoded_cuda_encoded_image = decode_jpeg(encoded_image_cuda.cpu())
assert decoded_cuda_encoded_image.shape == decoded_image_tv.shape
assert decoded_cuda_encoded_image.dtype == decoded_image_tv.dtype
assert (decoded_cuda_encoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3


@needs_cuda
def test_single_encode_jpeg_cuda_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"))

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"))

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda"))


@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(RuntimeError, match="Input tensor should be on CPU"):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(
RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
]
)

if torch.cuda.device_count() >= 2:
with pytest.raises(
RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
):
encode_jpeg(
[
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"),
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"),
]
)

with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
encode_jpeg([])


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
"img_path",
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "decode_jpeg_cuda.h"
#include "encode_decode_jpegs_cuda.h"

#include <ATen/ATen.h>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <torch/types.h>
#include "../image_read_mode.h"
#include "encode_jpegs_cuda.h"

namespace vision {
namespace image {
Expand All @@ -11,5 +12,9 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda(
ImageReadMode mode,
torch::Device device);

C10_EXPORT std::vector<torch::Tensor> encode_jpegs_cuda(
const std::vector<torch::Tensor>& decoded_images,
const int64_t quality);

} // namespace image
} // namespace vision
Loading

0 comments on commit 143d078

Please sign in to comment.