diff --git a/benchmarks/encoding.py b/benchmarks/encoding.py new file mode 100644 index 00000000000..f994b03c783 --- /dev/null +++ b/benchmarks/encoding.py @@ -0,0 +1,67 @@ +import os +import platform +import statistics + +import torch +import torch.utils.benchmark as benchmark +import torchvision + + +def print_machine_specs(): + print("Processor:", platform.processor()) + print("Platform:", platform.platform()) + print("Logical CPUs:", os.cpu_count()) + print(f"\nCUDA device: {torch.cuda.get_device_name()}") + print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") + + +def get_data(): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.PILToTensor(), + ] + ) + path = os.path.join(os.getcwd(), "data") + testset = torchvision.datasets.Places365( + root="./data", download=not os.path.exists(path), transform=transform, split="val" + ) + testloader = torch.utils.data.DataLoader( + testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch] + ) + return next(iter(testloader)) + + +def run_benchmark(batch): + results = [] + for device in ["cpu", "cuda"]: + batch_device = [t.to(device=device) for t in batch] + for size in [1, 100, 1000]: + for num_threads in [1, 12, 24]: + for stmt, strat in zip( + [ + "[torchvision.io.encode_jpeg(img) for img in batch_input]", + "torchvision.io.encode_jpeg(batch_input)", + ], + ["unfused", "fused"], + ): + batch_input = batch_device[:size] + t = benchmark.Timer( + stmt=stmt, + setup="import torchvision", + globals={"batch_input": batch_input}, + label="Image Encoding", + sub_label=f"{device.upper()} ({strat}): {stmt}", + description=f"{size} images", + num_threads=num_threads, + ) + results.append(t.blocked_autorange()) + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + print_machine_specs() + batch = get_data() + mean_h, mean_w = statistics.mean(t.shape[-2] for t in batch), statistics.mean(t.shape[-1] for t in batch) + print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}") + run_benchmark(batch) diff --git a/test/test_image.py b/test/test_image.py index 222322a8293..86018dccc42 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,3 +1,4 @@ +import concurrent.futures import glob import io import os @@ -10,7 +11,7 @@ import requests import torch import torchvision.transforms.functional as F -from common_utils import assert_equal, IN_OSS_CI, needs_cuda +from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence from torchvision.io.image import ( _read_png_16, @@ -508,6 +509,200 @@ def test_encode_jpeg(img_path, scripted): assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) +@needs_cuda +def test_encode_jpeg_cuda_device_param(): + path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path) + + data = read_image(path) + + current_device = torch.cuda.current_device() + current_stream = torch.cuda.current_stream() + num_devices = torch.cuda.device_count() + devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)] + results = [] + for device in devices: + print(f"python: device: {device}") + results.append(encode_jpeg(data.to(device=device))) + assert len(results) == len(devices) + for result in results: + assert torch.all(result.cpu() == results[0].cpu()) + + assert current_device == torch.cuda.current_device() + assert current_stream == torch.cuda.current_stream() + + +@needs_cuda +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("contiguous", (False, True)) +def test_encode_jpeg_cuda(img_path, scripted, contiguous): + decoded_image_tv = read_image(img_path) + encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg + + if "cmyk" in img_path: + pytest.xfail("Encoding a CMYK jpeg isn't supported") + if decoded_image_tv.shape[0] == 1: + pytest.xfail("Decoding a grayscale jpeg isn't supported") + # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013 + if contiguous: + decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.contiguous_format)[0] + else: + decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.channels_last)[0] + encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75) + decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu()) + + # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality + # instead, we re-decode the encoded image and compare to the original + abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item() + assert abs_mean_diff < 3 + + +@pytest.mark.parametrize("device", cpu_and_cuda()) +@pytest.mark.parametrize("scripted", (True, False)) +@pytest.mark.parametrize("contiguous", (True, False)) +def test_encode_jpegs_batch(scripted, contiguous, device): + if device == "cpu" and IS_MACOS: + pytest.skip("https://github.com/pytorch/vision/issues/8031") + decoded_images_tv = [] + for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): + if "cmyk" in jpeg_path: + continue + decoded_image = read_image(jpeg_path) + if decoded_image.shape[0] == 1: + continue + if contiguous: + decoded_image = decoded_image[None].contiguous(memory_format=torch.contiguous_format)[0] + else: + decoded_image = decoded_image[None].contiguous(memory_format=torch.channels_last)[0] + decoded_images_tv.append(decoded_image) + + encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg + + decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv] + encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75) + encoded_jpegs_tv_device = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device] + + for original, encoded_decoded in zip(decoded_images_tv, encoded_jpegs_tv_device): + c, h, w = original.shape + abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item() + assert abs_mean_diff < 3 + + # test multithreaded decoding + # in the current version we prevent this by using a lock but we still want to test it + num_workers = 10 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)] + encoded_images_threaded = [future.result() for future in futures] + assert len(encoded_images_threaded) == num_workers + for encoded_images in encoded_images_threaded: + assert len(decoded_images_tv_device) == len(encoded_images) + for i, (encoded_image_cuda, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)): + # make sure all the threads produce identical outputs + assert torch.all(encoded_image_cuda == encoded_images_threaded[0][i]) + + # make sure the outputs are identical or close enough to baseline + decoded_cuda_encoded_image = decode_jpeg(encoded_image_cuda.cpu()) + assert decoded_cuda_encoded_image.shape == decoded_image_tv.shape + assert decoded_cuda_encoded_image.dtype == decoded_image_tv.dtype + assert (decoded_cuda_encoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3 + + +@needs_cuda +def test_single_encode_jpeg_cuda_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): + encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): + encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda")) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda")) + + +@needs_cuda +def test_batch_encode_jpegs_cuda_errors(): + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(RuntimeError, match="Input tensor should be on CPU"): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises( + RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" + ): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"), + ] + ) + + if torch.cuda.device_count() >= 2: + with pytest.raises( + RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg" + ): + encode_jpeg( + [ + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"), + torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"), + ] + ) + + with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"): + encode_jpeg([]) + + @pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031") @pytest.mark.parametrize( "img_path", diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index ee7d432f30d..26fecc3e1f3 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,4 +1,4 @@ -#include "decode_jpeg_cuda.h" +#include "encode_decode_jpegs_cuda.h" #include diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h similarity index 62% rename from torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h rename to torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 496b355e9b7..7723d11d621 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -2,6 +2,7 @@ #include #include "../image_read_mode.h" +#include "encode_jpegs_cuda.h" namespace vision { namespace image { @@ -11,5 +12,9 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda( ImageReadMode mode, torch::Device device); +C10_EXPORT std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality); + } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp new file mode 100644 index 00000000000..1f10327ddbf --- /dev/null +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -0,0 +1,274 @@ +#include "encode_jpegs_cuda.h" +#if !NVJPEG_FOUND +namespace vision { +namespace image { +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality) { + TORCH_CHECK( + false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); +} +} // namespace image +} // namespace vision +#else + +#include +#include +#include +#include +#include +#include +#include +#include +#include "c10/core/ScalarType.h" + +namespace vision { +namespace image { + +// We use global variables to cache the encoder and decoder instances and +// reuse them across calls to the corresponding pytorch functions +std::mutex encoderMutex; +std::unique_ptr cudaJpegEncoder; + +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda"); + + // Some nvjpeg structures are not thread safe so we're keeping it single + // threaded for now. In the future this may be an opportunity to unlock + // further speedups + std::lock_guard lock(encoderMutex); + TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list"); + torch::Device device = decoded_images[0].device(); + at::cuda::CUDAGuard device_guard(device); + + // lazy init of the encoder class + // the encoder object holds on to a lot of state and is expensive to create, + // so we reuse it across calls. NB: the cached structures are device specific + // and cannot be reused across devices + if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) { + if (cudaJpegEncoder != nullptr) + delete cudaJpegEncoder.release(); + + cudaJpegEncoder = std::make_unique(device); + + // Unfortunately, we cannot rely on the smart pointer releasing the encoder + // object correctly upon program exit. This is because, when cudaJpegEncoder + // gets destroyed, the CUDA runtime may already be shut down, rendering all + // destroy* calls in the encoder destructor invalid. Instead, we use an + // atexit hook which executes after main() finishes, but hopefully before + // CUDA shuts down when the program exits. If CUDA is already shut down the + // destructor will detect this and will not attempt to destroy any encoder + // structures. + std::atexit([]() { delete cudaJpegEncoder.release(); }); + } + + std::vector contig_images; + contig_images.reserve(decoded_images.size()); + for (const auto& image : decoded_images) { + TORCH_CHECK( + image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + TORCH_CHECK( + image.device() == device, + "All input tensors must be on the same CUDA device when encoding with nvjpeg") + + TORCH_CHECK( + image.dim() == 3 && image.numel() > 0, + "Input data should be a 3-dimensional tensor"); + + TORCH_CHECK( + image.size(0) == 3, + "The number of channels should be 3, got: ", + image.size(0)); + + // nvjpeg requires images to be contiguous + if (image.is_contiguous()) { + contig_images.push_back(image); + } else { + contig_images.push_back(image.contiguous()); + } + } + + cudaJpegEncoder->set_quality(quality); + std::vector encoded_images; + at::cuda::CUDAEvent event; + event.record(cudaJpegEncoder->stream); + for (const auto& image : contig_images) { + auto encoded_image = cudaJpegEncoder->encode_jpeg(image); + encoded_images.push_back(encoded_image); + } + + // We use a dedicated stream to do the encoding and even though the results + // may be ready on that stream we cannot assume that they are also available + // on the current stream of the calling context when this function returns. We + // use a blocking event to ensure that this is indeed the case. Crucially, we + // do not want to block the host at this particular point + // (which is what cudaStreamSynchronize would do.) Events allow us to + // synchronize the streams without blocking the host. + event.block(at::cuda::getCurrentCUDAStream( + cudaJpegEncoder->original_device.has_index() + ? cudaJpegEncoder->original_device.index() + : 0)); + return encoded_images; +} + +CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) + : original_device{torch::kCUDA, torch::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + nvjpegStatus_t status; + status = nvjpegCreateSimple(&nvjpeg_handle); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg handle: ", + status); + + status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder state: ", + status); + + status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder params: ", + status); +} + +CUDAJpegEncoder::~CUDAJpegEncoder() { + /* + The below code works on Mac and Linux, but fails on Windows. + This is because on Windows, the atexit hook which calls this + destructor executes after cuda is already shut down causing SIGSEGV. + We do not have a solution to this problem at the moment, so we'll + just leak the libnvjpeg & cuda variables for the time being and hope + that the CUDA runtime handles cleanup for us. + Please send a PR if you have a solution for this problem. + */ + + // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is + // still + // // initialized. If it is not, we can skip the rest of this function as it + // is + // // unsafe to execute. + // int deviceCount = 0; + // cudaError_t error = cudaGetDeviceCount(&deviceCount); + // if (error != cudaSuccess) + // return; // CUDA runtime has already shut down. There's nothing we can do + // // now. + + // nvjpegStatus_t status; + + // status = nvjpegEncoderParamsDestroy(nv_enc_params); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg encoder params: ", + // status); + + // status = nvjpegEncoderStateDestroy(nv_enc_state); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg encoder state: ", + // status); + + // cudaStreamSynchronize(stream); + + // status = nvjpegDestroy(nvjpeg_handle); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); +} + +torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { + int channels = src_image.size(0); + int height = src_image.size(1); + int width = src_image.size(2); + + nvjpegStatus_t status; + cudaError_t cudaStatus; + status = nvjpegEncoderParamsSetSamplingFactors( + nv_enc_params, NVJPEG_CSS_444, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params sampling factors: ", + status); + + nvjpegImage_t target_image; + for (int c = 0; c < channels; c++) { + target_image.channel[c] = src_image[c].data_ptr(); + // this is why we need contiguous tensors + target_image.pitch[c] = width; + } + for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) { + target_image.channel[c] = nullptr; + target_image.pitch[c] = 0; + } + // Encode the image + status = nvjpegEncodeImage( + nvjpeg_handle, + nv_enc_state, + nv_enc_params, + &target_image, + NVJPEG_INPUT_RGB, + width, + height, + stream); + + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status); + // Retrieve length of the encoded image + size_t length; + status = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, nv_enc_state, NULL, &length, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image stream state: ", + status); + + // Synchronize the stream to ensure that the encoded image is ready + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + + // Reserve buffer for the encoded image + torch::Tensor encoded_image = torch::empty( + {static_cast(length)}, + torch::TensorOptions() + .dtype(torch::kByte) + .layout(torch::kStrided) + .device(target_device) + .requires_grad(false)); + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + // Retrieve the encoded image + status = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, + nv_enc_state, + encoded_image.data_ptr(), + &length, + 0); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image: ", + status); + return encoded_image; +} + +void CUDAJpegEncoder::set_quality(const int64_t quality) { + nvjpegStatus_t paramsQualityStatus = + nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); + TORCH_CHECK( + paramsQualityStatus == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params quality: ", + paramsQualityStatus); +} + +} // namespace image +} // namespace vision + +#endif // NVJPEG_FOUND diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h new file mode 100644 index 00000000000..543940f1585 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#if NVJPEG_FOUND + +#include +#include +#include + +namespace vision { +namespace image { + +class CUDAJpegEncoder { + public: + CUDAJpegEncoder(const torch::Device& device); + ~CUDAJpegEncoder(); + + torch::Tensor encode_jpeg(const torch::Tensor& src_image); + + void set_quality(const int64_t quality); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + protected: + nvjpegEncoderState_t nv_enc_state; + nvjpegEncoderParams_t nv_enc_params; + nvjpegHandle_t nvjpeg_handle; +}; +} // namespace image +} // namespace vision +#endif diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 2909938f0bd..68267b72604 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -27,6 +27,7 @@ static auto registry = .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_image) .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) + .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) .op("image::_jpeg_version", &_jpeg_version) .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 457b1548d4d..f7e9b63801c 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -7,4 +7,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -#include "cuda/decode_jpeg_cuda.h" +#include "cuda/encode_decode_jpegs_cuda.h" diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e7a429880a2..df2fdef3580 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import List, Union from warnings import warn import torch @@ -68,7 +69,9 @@ def write_file(filename: str, data: torch.Tensor) -> None: def decode_png( - input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False + input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, + apply_exif_orientation: bool = False, ) -> torch.Tensor: """ Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. @@ -180,28 +183,42 @@ def decode_jpeg( return output -def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: +def encode_jpeg( + input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75 +) -> Union[torch.Tensor, List[torch.Tensor]]: """ - Takes an input tensor in CHW layout and returns a buffer with the contents - of its corresponding JPEG file. + Takes a (list of) input tensor(s) in CHW layout and returns a (list of) buffer(s) with the contents + of the corresponding JPEG file(s). + + .. note:: + Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``. + For CPU tensors the performance is equivalent. Args: - input (Tensor[channels, image_height, image_width])): int8 image tensor of - ``c`` channels, where ``c`` must be 1 or 3. - quality (int): Quality of the resulting JPEG file, it must be a number between + input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]): + (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3 + quality (int): Quality of the resulting JPEG file(s). Must be a number between 1 and 100. Default: 75 Returns: - output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the - JPEG file. + output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(encode_jpeg) if quality < 1 or quality > 100: raise ValueError("Image quality should be a positive number between 1 and 100") - - output = torch.ops.image.encode_jpeg(input, quality) - return output + if isinstance(input, list): + if not input: + raise ValueError("encode_jpeg requires at least one input tensor when a list is passed") + if input[0].device.type == "cuda": + return torch.ops.image.encode_jpegs_cuda(input, quality) + else: + return [torch.ops.image.encode_jpeg(image, quality) for image in input] + else: # single input tensor + if input.device.type == "cuda": + return torch.ops.image.encode_jpegs_cuda([input], quality)[0] + else: + return torch.ops.image.encode_jpeg(input, quality) def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): @@ -218,11 +235,14 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(write_jpeg) output = encode_jpeg(input, quality) + assert isinstance(output, torch.Tensor) # Needed for torchscript write_file(filename, output) def decode_image( - input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False + input: torch.Tensor, + mode: ImageReadMode = ImageReadMode.UNCHANGED, + apply_exif_orientation: bool = False, ) -> torch.Tensor: """ Detect whether an image is a JPEG, PNG or GIF and performs the appropriate @@ -251,7 +271,9 @@ def decode_image( def read_image( - path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False + path: str, + mode: ImageReadMode = ImageReadMode.UNCHANGED, + apply_exif_orientation: bool = False, ) -> torch.Tensor: """ Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor. diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index eac27f37022..60b49099fc5 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -78,9 +78,14 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: if image.shape[0] == 0: # degenerate return image.reshape(original_shape).clone() - image = [decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0])] - image = torch.stack(image, dim=0).view(original_shape) - return image + images = [] + for i in range(image.shape[0]): + encoded_image = encode_jpeg(image[i], quality=quality) + assert isinstance(encoded_image, torch.Tensor) # For torchscript + images.append(decode_jpeg(encoded_image)) + + images = torch.stack(images, dim=0).view(original_shape) + return images @_register_kernel_internal(jpeg, tv_tensors.Video)