From d615dbe00a48551c249e32c66128becde40c66c7 Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 3 Apr 2020 09:08:30 -0400 Subject: [PATCH 1/3] Adds normalize parameter to totensor Adds normalize parameter to totensor defaults to true for backwards compatibility. --- torchvision/transforms/functional.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7ce1fb6ab36..3a593c8ff7e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -27,13 +27,14 @@ def _is_numpy_image(img): return img.ndim in {2, 3} -def to_tensor(pic): +def to_tensor(pic, normalize=True): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. See ``ToTensor`` for more details. Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + normalize (bool): whether the output tensor should be normalized. Returns: Tensor: Converted image. @@ -52,7 +53,10 @@ def to_tensor(pic): img = torch.from_numpy(pic.transpose((2, 0, 1))) # backward compatibility if isinstance(img, torch.ByteTensor): - return img.float().div(255) + if normalize: + return img.float().div(255) + else: + return img.float() else: return img @@ -77,7 +81,10 @@ def to_tensor(pic): # put it from HWC to CHW format img = img.permute((2, 0, 1)).contiguous() if isinstance(img, torch.ByteTensor): - return img.float().div(255) + if normalize: + return img.float().div(255) + else: + return img.float() else: return img From 0406116fc1687cb96960704f525ce50b6ceca8cf Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 3 Apr 2020 09:11:53 -0400 Subject: [PATCH 2/3] Adds normalize parameter to ToTensor. Adds normalize parameter to ToTensor defaults to true for backwards compatibility. --- torchvision/transforms/transforms.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 10783c8e53d..035856bc57b 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -74,22 +74,24 @@ class ToTensor(object): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. Converts a PIL Image or numpy.ndarray (H x W x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + [0, 255] to a torch.FloatTensor of shape (C x H x W) + and normalizes the output in the range [0.0, 1.0] if normalize is True if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8 In the other cases, tensors are returned without scaling. """ - def __call__(self, pic): + def __call__(self, pic, normalize=True): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + normalize (bool): Whether to normalize the resulting output tensor. Returns: Tensor: Converted image. """ - return F.to_tensor(pic) + return F.to_tensor(pic, normalize=normalize) def __repr__(self): return self.__class__.__name__ + '()' From cd6e66f97d20a4d0cbbb96227431d186206b379a Mon Sep 17 00:00:00 2001 From: Steven Basart Date: Fri, 3 Apr 2020 09:22:10 -0400 Subject: [PATCH 3/3] Fixes pylint issues and rewords the documentation --- torchvision/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 035856bc57b..25dc450c9a9 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -74,10 +74,10 @@ class ToTensor(object): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. Converts a PIL Image or numpy.ndarray (H x W x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) - and normalizes the output in the range [0.0, 1.0] if normalize is True + [0, 255] to a torch.FloatTensor of shape (C x H x W) if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8 + and normalizes the output in the range [0.0, 1.0] if normalize is True In the other cases, tensors are returned without scaling. """