Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow decode_image to support paths #8624

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
:toctree: generated/
:template: function.rst

read_image
decode_image
encode_jpeg
decode_jpeg
Expand All @@ -38,6 +37,13 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.

ImageReadMode

Obsolete decoding function:

.. autosummary::
:toctree: generated/
:template: class.rst

read_image


Video
Expand Down
16 changes: 8 additions & 8 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ Here is an example of how to use the pre-trained image classification models:

.. code:: python

from torchvision.io import read_image
from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
Expand Down Expand Up @@ -283,10 +283,10 @@ Here is an example of how to use the pre-trained quantized image classification

.. code:: python

from torchvision.io import read_image
from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
Expand Down Expand Up @@ -339,11 +339,11 @@ Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

from torchvision.io.image import read_image
from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = read_image("gallery/assets/dog1.jpg")
img = decode_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
Expand Down Expand Up @@ -411,12 +411,12 @@ Here is an example of how to use the pre-trained object detection models:
.. code:: python


from torchvision.io.image import read_image
from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
Expand Down
10 changes: 5 additions & 5 deletions gallery/others/plot_repurposing_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def show(imgs):
# We will take images and masks from the `PenFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.


from torchvision.io import read_image
from torchvision.io import decode_image

img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
img = read_image(img_path)
mask = read_image(mask_path)
img = decode_image(img_path)
mask = decode_image(mask_path)


# %%
Expand Down Expand Up @@ -181,8 +181,8 @@ def __getitem__(self, idx):
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

img = read_image(img_path)
mask = read_image(mask_path)
img = decode_image(img_path)
mask = decode_image(mask_path)

img = F.convert_image_dtype(img, dtype=torch.float)
mask = F.convert_image_dtype(mask, dtype=torch.float)
Expand Down
6 changes: 3 additions & 3 deletions gallery/others/plot_scripted_tensor_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import read_image
from torchvision.io import decode_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)
Expand All @@ -39,8 +39,8 @@
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:

dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg'))

transforms = torch.nn.Sequential(
v1.RandomCrop(224),
Expand Down
10 changes: 5 additions & 5 deletions gallery/others/plot_visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def show(imgs):
# image of dtype ``uint8`` as input.

from torchvision.utils import make_grid
from torchvision.io import read_image
from torchvision.io import decode_image
from pathlib import Path

dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
dog1_int = decode_image(str(Path('../assets') / 'dog1.jpg'))
dog2_int = decode_image(str(Path('../assets') / 'dog2.jpg'))
dog_list = [dog1_int, dog2_int]

grid = make_grid(dog_list)
Expand Down Expand Up @@ -362,9 +362,9 @@ def show(imgs):
#

from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import read_image
from torchvision.io import decode_image

person_int = read_image(str(Path("../assets") / "person1.jpg"))
person_int = decode_image(str(Path("../assets") / "person1.jpg"))

weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
Expand Down
4 changes: 2 additions & 2 deletions gallery/transforms/plot_transforms_getting_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image
from torchvision.io import decode_image

torch.manual_seed(1)

# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = read_image(str(Path('../assets') / 'astronaut.jpg'))
img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")

# %%
Expand Down
10 changes: 5 additions & 5 deletions test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
import torchvision
from torchvision.io import decode_jpeg, decode_webp, read_file, read_image
from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file
from torchvision.models import resnet50, ResNet50_Weights


Expand All @@ -21,13 +21,13 @@ def smoke_test_torchvision() -> None:


def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
if img_webp.shape != (3, 100, 100):
raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")

Expand All @@ -54,7 +54,7 @@ def smoke_test_compile() -> None:


def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
Expand Down
21 changes: 21 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,5 +1044,26 @@ def test_decode_heic(decode_fun, scripted):
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.parametrize("input_type", ("Path", "str", "tensor"))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_image_path(input_type, scripted):
# Check that decode_image can support not just tensors as input
path = next(get_images(IMAGE_ROOT, ".jpg"))
if input_type == "Path":
input = Path(path)
elif input_type == "str":
input = path
elif input_type == "tensor":
input = read_file(path)
else:
raise ValueError("Oops")

if scripted and input_type == "Path":
pytest.xfail(reason="Can't pass a Path when scripting")

decode_fun = torch.jit.script(decode_image) if scripted else decode_image
decode_fun(input)


if __name__ == "__main__":
pytest.main([__file__])
40 changes: 10 additions & 30 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@


def decode_image(
input: torch.Tensor,
input: Union[torch.Tensor, str],

Check warning on line 280 in torchvision/io/image.py

View workflow job for this annotation

GitHub Actions / bc

Function decode_image: input changed from torch.Tensor to Union[torch.Tensor, str]
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the
appropriate operation to decode the image into a Tensor.
"""Decode an image into a tensor.

Currently supported image formats are jpeg, png, gif and webp.

The values of the output tensor are in uint8 in [0, 255] for most cases.

Expand All @@ -295,8 +295,9 @@
tensor.

Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
image.
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
tensor is passed, it must be one dimensional uint8 tensor containing
the raw bytes of the image. Otherwise, this must be a path to the image file.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
Expand All @@ -309,6 +310,8 @@
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_image)
if not isinstance(input, torch.Tensor):
input = read_file(str(input))
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
return output

Expand All @@ -318,30 +321,7 @@
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Reads a JPEG, PNG, WEBP, or GIF image into a Tensor.

The values of the output tensor are in uint8 in [0, 255] for most cases.

If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``. Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.

Args:
path (str or ``pathlib.Path``): path of the image.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes. Only applies to JPEG and PNG images.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Only applies to JPEG and PNG images. Default: False.

Returns:
output (Tensor[image_channels, image_height, image_width])
"""
"""[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_image)
data = read_file(path)
Expand Down
Loading