Skip to content

Commit

Permalink
Adds ImageToTensor module and resize for non-batched images (kornia#978)
Browse files Browse the repository at this point in the history
* adapt resize for non batch

* implement ImageToTensor module

* handle length 2 and 3
  • Loading branch information
edgarriba authored Apr 26, 2021
1 parent 454f9cd commit 37b6007
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 2 deletions.
12 changes: 11 additions & 1 deletion kornia/geometry/transform/affwarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
warp_affine3d, get_projective_transform
)
from kornia.utils import _extract_device_dtype
from kornia.utils.image import _to_bchw

__all__ = [
"affine",
Expand Down Expand Up @@ -544,7 +545,16 @@ def resize(input: torch.Tensor, size: Union[int, Tuple[int, int]],
if size == input_size:
return input

return torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
# TODO: find a proper way to handle this cases in the future
input_tmp = _to_bchw(input)

output = torch.nn.functional.interpolate(
input_tmp, size=size, mode=interpolation, align_corners=align_corners)

if len(input.shape) != len(output.shape):
output = output.squeeze()

return output


def rescale(
Expand Down
3 changes: 2 additions & 1 deletion kornia/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .one_hot import one_hot
from .grid import create_meshgrid, create_meshgrid3d
from .image import tensor_to_image, image_to_tensor
from .image import tensor_to_image, image_to_tensor, ImageToTensor
from .pointcloud_io import save_pointcloud_ply, load_pointcloud_ply
from .draw import draw_rectangle
from .helpers import _extract_device_dtype
Expand All @@ -17,4 +17,5 @@
"load_pointcloud_ply",
"draw_rectangle",
"_extract_device_dtype",
"ImageToTensor",
]
16 changes: 16 additions & 0 deletions kornia/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch
import torch.nn as nn


def image_to_tensor(image: np.ndarray, keepdim: bool = True) -> torch.Tensor:
Expand Down Expand Up @@ -141,3 +142,18 @@ def tensor_to_image(tensor: torch.Tensor) -> np.ndarray:
"Cannot process tensor with shape {}".format(input_shape))

return image


class ImageToTensor(nn.Module):
"""Converts a numpy image to a PyTorch 4d tensor image.
Args:
keepdim (bool): If ``False`` unsqueeze the input image to match the shape
:math:`(B, H, W, C)`. Default: ``True``
"""
def __init__(self, keepdim: bool = False):
super().__init__()
self.keepdim = keepdim

def forward(self, x: np.ndarray) -> torch.Tensor:
return image_to_tensor(x, keepdim=self.keepdim)
5 changes: 5 additions & 0 deletions test/geometry/transform/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def test_smoke(self, device, dtype):
out = kornia.resize(inp, (3, 4))
assert_allclose(inp, out, atol=1e-4, rtol=1e-4)

def test_no_batch(self, device, dtype):
inp = torch.rand(3, 3, 4, device=device, dtype=dtype)
out = kornia.resize(inp, (2, 5))
assert out.shape == (3, 2, 5)

def test_upsize(self, device, dtype):
inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
out = kornia.resize(inp, (6, 8))
Expand Down
5 changes: 5 additions & 0 deletions test/utils/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np

import torch
from torch.testing import assert_allclose

import kornia as kornia
import kornia.testing as utils # test utils

Expand Down Expand Up @@ -43,6 +45,9 @@ def test_image_to_tensor(input_shape, expected):
assert tensor.shape == expected
assert isinstance(tensor, torch.Tensor)

to_tensor = kornia.utils.ImageToTensor(keepdim=False)
assert_allclose(tensor, to_tensor(image))


@pytest.mark.parametrize("input_shape, expected",
[((4, 4), (1, 4, 4)),
Expand Down

0 comments on commit 37b6007

Please sign in to comment.