Skip to content

Commit

Permalink
Fixes "make Image(some_pil_image)" from TODO list (pytorch#7217)
Browse files Browse the repository at this point in the history
Added tests
  • Loading branch information
vfdev-5 committed Feb 13, 2023
1 parent acabaf8 commit 10f42ee
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
23 changes: 23 additions & 0 deletions test/test_prototype_datapoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
import torch

from PIL import Image
from torchvision.prototype import datapoints


Expand Down Expand Up @@ -130,3 +132,24 @@ def test_wrap_like():
assert type(label_new) is datapoints.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
image = datapoints.Image(data)
assert isinstance(image, torch.Tensor)
assert image.ndim == 3 and image.shape[0] == 3


@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
mask = datapoints.Mask(data)
assert isinstance(mask, torch.Tensor)
assert mask.ndim == 3 and mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
def test_bbox_instance(data):
bboxes = datapoints.BoundingBox(data, format="XYXY", spatial_size=(32, 32))
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
5 changes: 5 additions & 0 deletions torchvision/prototype/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def __new__(
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
) -> Image:
if isinstance(data, PIL.Image.Image):
from ..transforms import functional as F

data = F.pil_to_tensor(data)

tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if tensor.ndim < 2:
raise ValueError
Expand Down
6 changes: 6 additions & 0 deletions torchvision/prototype/datapoints/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.transforms import InterpolationMode

Expand All @@ -21,6 +22,11 @@ def __new__(
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
) -> Mask:
if isinstance(data, PIL.Image.Image):
from ..transforms import functional as F

data = F.pil_to_tensor(data)

tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor)

Expand Down

0 comments on commit 10f42ee

Please sign in to comment.