Skip to content

Commit

Permalink
GPU jpeg decoder: add batch support and hardware decoding (#8496)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
  • Loading branch information
deekay42 and NicolasHug authored Aug 7, 2024
1 parent 5242d6a commit 0d80848
Show file tree
Hide file tree
Showing 10 changed files with 934 additions and 317 deletions.
67 changes: 0 additions & 67 deletions benchmarks/encoding.py

This file was deleted.

99 changes: 99 additions & 0 deletions benchmarks/encoding_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import platform
import statistics

import torch
import torch.utils.benchmark as benchmark
import torchvision


def print_machine_specs():
print("Processor:", platform.processor())
print("Platform:", platform.platform())
print("Logical CPUs:", os.cpu_count())
print(f"\nCUDA device: {torch.cuda.get_device_name()}")
print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


def get_data():
transform = torchvision.transforms.Compose(
[
torchvision.transforms.PILToTensor(),
]
)
path = os.path.join(os.getcwd(), "data")
testset = torchvision.datasets.Places365(
root="./data", download=not os.path.exists(path), transform=transform, split="val"
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch]
)
return next(iter(testloader))


def run_encoding_benchmark(decoded_images):
results = []
for device in ["cpu", "cuda"]:
decoded_images_device = [t.to(device=device) for t in decoded_images]
for size in [1, 100, 1000]:
for num_threads in [1, 12, 24]:
for stmt, strat in zip(
[
"[torchvision.io.encode_jpeg(img) for img in decoded_images_device_trunc]",
"torchvision.io.encode_jpeg(decoded_images_device_trunc)",
],
["unfused", "fused"],
):
decoded_images_device_trunc = decoded_images_device[:size]
t = benchmark.Timer(
stmt=stmt,
setup="import torchvision",
globals={"decoded_images_device_trunc": decoded_images_device_trunc},
label="Image Encoding",
sub_label=f"{device.upper()} ({strat}): {stmt}",
description=f"{size} images",
num_threads=num_threads,
)
results.append(t.blocked_autorange())
compare = benchmark.Compare(results)
compare.print()


def run_decoding_benchmark(encoded_images):
results = []
for device in ["cpu", "cuda"]:
for size in [1, 100, 1000]:
for num_threads in [1, 12, 24]:
for stmt, strat in zip(
[
f"[torchvision.io.decode_jpeg(img, device='{device}') for img in encoded_images_trunc]",
f"torchvision.io.decode_jpeg(encoded_images_trunc, device='{device}')",
],
["unfused", "fused"],
):
encoded_images_trunc = encoded_images[:size]
t = benchmark.Timer(
stmt=stmt,
setup="import torchvision",
globals={"encoded_images_trunc": encoded_images_trunc},
label="Image Decoding",
sub_label=f"{device.upper()} ({strat}): {stmt}",
description=f"{size} images",
num_threads=num_threads,
)
results.append(t.blocked_autorange())
compare = benchmark.Compare(results)
compare.print()


if __name__ == "__main__":
print_machine_specs()
decoded_images = get_data()
mean_h, mean_w = statistics.mean(t.shape[-2] for t in decoded_images), statistics.mean(
t.shape[-1] for t in decoded_images
)
print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}")
run_encoding_benchmark(decoded_images)
encoded_images_cuda = torchvision.io.encode_jpeg([img.cuda() for img in decoded_images])
encoded_images_cpu = [img.cpu() for img in encoded_images_cuda]
run_decoding_benchmark(encoded_images_cpu)
121 changes: 99 additions & 22 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,23 +413,32 @@ def test_read_interlaced_png():


@needs_cuda
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
if "cmyk" in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported")
def test_decode_jpegs_cuda(mode, scripted):
encoded_images = []
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
if "cmyk" in jpeg_path:
continue
encoded_image = read_file(jpeg_path)
encoded_images.append(encoded_image)
decoded_images_cpu = decode_jpeg(encoded_images, mode=mode)
decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg

data = read_file(img_path)
img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device="cuda")
# test multithreaded decoding
# in the current version we prevent this by using a lock but we still want to test it
num_workers = 10

# Some difference expected between jpeg implementations
assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
decoded_images_threaded = [future.result() for future in futures]
assert len(decoded_images_threaded) == num_workers
for decoded_images in decoded_images_threaded:
assert len(decoded_images) == len(encoded_images)
for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu):
assert decoded_image_cuda.shape == decoded_image_cpu.shape
assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8
assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2


@needs_cuda
Expand All @@ -440,25 +449,95 @@ def test_decode_image_cuda_raises():


@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
"""Make sure we can pass a string or a torch.device as device param"""
def test_decode_jpeg_cuda_device_param():
path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
data = read_file(path)
decode_jpeg(data, device=cuda_device)
current_device = torch.cuda.current_device()
current_stream = torch.cuda.current_stream()
num_devices = torch.cuda.device_count()
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
results = []
for device in devices:
results.append(decode_jpeg(data, device=device))
assert len(results) == len(devices)
for result in results:
assert torch.all(result.cpu() == results[0].cpu())
assert current_device == torch.cuda.current_device()
assert current_stream == torch.cuda.current_stream()


@needs_cuda
def test_decode_jpeg_cuda_errors():
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(data.reshape(-1, 1), device="cuda")
with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
with pytest.raises(ValueError, match="must be tensors"):
decode_jpeg([1, 2, 3])
with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
decode_jpeg(data.to("cuda"), device="cuda")
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(data.to(torch.float), device="cuda")
with pytest.raises(RuntimeError, match="Expected a cuda device"):
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")
with pytest.raises(RuntimeError, match="Expected the device parameter to be a cuda device"):
torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu")
with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
decode_jpeg(
torch.empty((100,), dtype=torch.uint8, device="cuda"),
)
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
decode_jpeg(
[
torch.empty((100,), dtype=torch.uint8, device="cuda"),
torch.empty((100,), dtype=torch.uint8, device="cuda"),
]
)

with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
decode_jpeg(
[
torch.empty((100,), dtype=torch.uint8, device="cuda"),
torch.empty((100,), dtype=torch.uint8, device="cuda"),
],
device="cuda",
)

with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
decode_jpeg(
[
torch.empty((100,), dtype=torch.uint8, device="cpu"),
torch.empty((100,), dtype=torch.uint8, device="cuda"),
],
device="cuda",
)

with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(
[
torch.empty((100,), dtype=torch.uint8),
torch.empty((100,), dtype=torch.float32),
],
device="cuda",
)

with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(
[
torch.empty((100,), dtype=torch.uint8),
torch.empty((1, 100), dtype=torch.uint8),
],
device="cuda",
)

with pytest.raises(RuntimeError, match="Error while decoding JPEG images"):
decode_jpeg(
[
torch.empty((100,), dtype=torch.uint8),
torch.empty((100,), dtype=torch.uint8),
],
device="cuda",
)

with pytest.raises(ValueError, match="Input list must contain at least one element"):
decode_jpeg([], device="cuda")


def test_encode_jpeg_errors():
Expand Down Expand Up @@ -515,12 +594,10 @@ def test_encode_jpeg_cuda_device_param():
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
results = []
for device in devices:
print(f"python: device: {device}")
results.append(encode_jpeg(data.to(device=device)))
assert len(results) == len(devices)
for result in results:
assert torch.all(result.cpu() == results[0].cpu())

assert current_device == torch.cuda.current_device()
assert current_stream == torch.cuda.current_stream()

Expand Down
Loading

0 comments on commit 0d80848

Please sign in to comment.