-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pil_to_tensor to functionals #2092
Changes from 6 commits
286e316
f7eb489
f90b3bc
9c2fd3b
08ab5ec
cb19ed4
38ad5f3
1fa91a8
0fefbcb
7662b23
75be7bb
123503a
eff1db0
266860a
610fc1e
b9cca77
1b10f77
fbf661c
598107f
d69048e
fa1084c
2cb7a4f
3d565fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,6 +82,45 @@ def to_tensor(pic): | |
return img | ||
|
||
|
||
def as_tensor(pic): | ||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of same type. | ||
|
||
See ``AsTensor`` for more details. | ||
|
||
Args: | ||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor. | ||
|
||
Returns: | ||
Tensor: Converted image. | ||
""" | ||
if not(_is_pil_image(pic) or _is_numpy(pic)): | ||
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) | ||
|
||
if _is_numpy(pic) and not _is_numpy_image(pic): | ||
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) | ||
|
||
if isinstance(pic, np.ndarray): | ||
# handle numpy array | ||
if pic.ndim == 2: | ||
pic = pic[:, :, None] | ||
|
||
img = torch.as_tensor(pic.transpose((2, 0, 1))) | ||
return img | ||
|
||
if accimage is not None and isinstance(pic, accimage.Image): | ||
xksteven marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) | ||
pic.copyto(nppic) | ||
return torch.as_tensor(nppic) | ||
|
||
# handle PIL Image | ||
img = torch.as_tensor(np.asarray(pic)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line will still produce the same bug mentioned in #2194. Converting to numpy with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we will wait until PyTorch fixes this behavior in master, as making a copy would be fairly expensive. If the warnings are too annoying, an alternative would be to only use |
||
|
||
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) | ||
# put it from HWC to CHW format | ||
img = img.permute((2, 0, 1)) | ||
return img | ||
|
||
|
||
def to_pil_image(pic, mode=None): | ||
"""Convert a tensor or an ndarray to PIL Image. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we indeed only consider that this function only supports the
PIL -> tensor
conversion, then maybe a better name would bepil_to_tensor
or something like that? Open to suggestionsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you take a second pass at your earliest convenience?
One of the tests is a little awkward in that ToPILImage converts FloatTensors to bytes.
The other thing was I'm unsure of the parameter name "swap_to_channelsfirst".
Let me know.