Skip to content

Adds normalize parameter to ToTensor operation #2060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ def _is_numpy_image(img):
return img.ndim in {2, 3}


def to_tensor(pic):
def to_tensor(pic, normalize=True):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

See ``ToTensor`` for more details.

Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
normalize (bool): whether the output tensor should be normalized.

Returns:
Tensor: Converted image.
Expand All @@ -52,7 +53,10 @@ def to_tensor(pic):
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
if normalize:
return img.float().div(255)
else:
return img.float()
else:
return img

Expand All @@ -77,7 +81,10 @@ def to_tensor(pic):
# put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
if normalize:
return img.float().div(255)
else:
return img.float()
else:
return img

Expand Down
8 changes: 5 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,24 @@ class ToTensor(object):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
[0, 255] to a torch.FloatTensor of shape (C x H x W)
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
and normalizes the output in the range [0.0, 1.0] if normalize is True

In the other cases, tensors are returned without scaling.
"""

def __call__(self, pic):
def __call__(self, pic, normalize=True):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
normalize (bool): Whether to normalize the resulting output tensor.

Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
return F.to_tensor(pic, normalize=normalize)

def __repr__(self):
return self.__class__.__name__ + '()'
Expand Down