-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pil_to_tensor to functionals #2092
Conversation
Similar functionality to to_tensor without the default conversion to float and division by 255. Also adds support for Image mode 'L'.
Adds tests to AsTensor and removes the conversion to float and division by 255.
Calls the as_tensor function in functionals and adds the function AsTensor as callable from transforms.
This was handled by the else condition previously so I'll remove it.
Adds two line breaks between functions to fix lint issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, it seems that we always convert to numpy
first and then import it. I think it would be much clearer if we did something like this:
if isinstance(pic, PIL_Image):
pic = pil_to_numpy(pic)
elif isinstance(pic, accimage_Image):
pic = accimage_to_numpy(pic)
return numpy_to_torchvision(pic)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, it seems that we always convert to
numpy
first and then import it. I think it would be much clearer if we did something like this:if isinstance(pic, PIL_Image): pic = pil_to_numpy(pic) elif isinstance(pic, accimage_Image): pic = accimage_to_numpy(pic) return numpy_to_torchvision(pic)
Are you suggesting to have a helper function for pil_to_numpy(pic)
instead of including it in the same function as_tensor(pic)
?
Should they be callable or should it be _pil_to_numpy(pic)
so as not to expose it to public API? I think I'd prefer the latter personally.
No strong opinion here. My point is that as is (and as it was before) the function is quite confusing. First, we handle I think it would be much clearer to convert
I suggest you do them nested within |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
I did a first pass, let me know what you think
torchvision/transforms/functional.py
Outdated
|
||
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) | ||
# put it from HWC to CHW format | ||
img = img.permute((2, 0, 1)).contiguous() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm unsure if we want to call contiguous()
here.
If fact, I was thinking about letting the tensor be with a different memory format (channels_last, HWC).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain your rationale for the HWC memory format?
Almost all of the downstream operations expect the CHW format so should there be a separate function that handles this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @xksteven here. The only reason I see for changing the format is if someone just wants the import and want to squeeze every milli- / microsecond he can get. If that is the intention, I suggest we add a channels_first
flag that defaults to True
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I miscommunicated what my intentions were, sorry about that.
What I wanted to say was that images are naturally stored as HWC, while all PyTorch operations expect CHW (up to now). But there is an ongoing effort on PyTorch to add support for channels_last
, which takes tensors of shape CHW but with strides such that is just a transposed HWC
(no contiguous call).
Given that all downstream operations in torchvision should support arbitrarily-strided tensors, I would vote for returning non-contiguous tensors, so that PyTorch, in the future when dedicated kernel support for channels_last is implemented) we will be able to handle those in an efficient manner.
torchvision/transforms/functional.py
Outdated
if isinstance(pic, np.ndarray): | ||
# handle numpy array | ||
if pic.ndim == 2: | ||
pic = pic[:, :, None] | ||
|
||
img = torch.from_numpy(pic.transpose((2, 0, 1))) | ||
return img |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if we should even support handling np.ndarray
in this function.
Indeed, the data in a np.ndarray
can have any format (for example, it can be a float array with range from 0-255), and we can't properly handle all possible cases. It's the responsibility of the user to do it. Plus, if the user passes OpenCV arrays to the function, it will be in BGR
format (different from RGB from Pillow and what we use in torchvision)
As such, I think that we should probably only handle PIL Images -- handling numpy arrays is trivial from the user perspective (torch.as_tensor(ndarray)
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not really for or against keeping the numpy conversion here.
With that said I think the primary purpose of this function is doing the conversion to a pytorch tensor format and making it into a channels first format. The scope of the inputs can therefore be broadened or narrowed without really affecting the goals of the function.
The OpenCV arrays will still come out to be channels first after being passed through this function. We do not need to make any other assumptions other than the data format is HWC (or in the case of black&white images that it is HW).
So let me know if you think it's best to drop numpy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think what we should aim for is the least amount of potential user-errors or surprises.
Indeed, both OpenCV and scikit-image
returns images as ndarrays of HWC format. But the color convention is not the same, and from our perspective there is no way to know if the array is indeed HWC or not (imagine multi-band images for example).
What scaries me is that the ndarray that is passed could also be CHW for some reason, and the function would just return something wrong.
For that reason, we could try to make the scope of this function to be as narrow as possible, so that we can be sure we won't be mishandling user inputs.
PIL Images and AccImage have a well-defined format representation (although I'm not sure that many people use AccImage actually), which is not the case for ndarrays which are generic data containers.
Let me know what you think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am okay with the narrowed scope and have made changes to reflect that.
As a separate point I would like to keep the functionality of ToPILImage to accept numpy arrays. That way users can still load in numpy (or other formats that they can convert to numpy) and as long as the user can convert the numpy array to PIL Image the sequence of compose will still work (as written below).
transforms.Compose([
transforms.ToPILImage(),
...
transforms.PILToTensor(),
])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, changing ToPILImage
was not in the plans (it would be a backwards-incompatible change)
torchvision/transforms/functional.py
Outdated
def as_tensor(pic): | ||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of same type. | ||
|
||
See ``AsTensor`` for more details. | ||
|
||
Args: | ||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor. | ||
|
||
Returns: | ||
Tensor: Converted image. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we indeed only consider that this function only supports the PIL -> tensor
conversion, then maybe a better name would be pil_to_tensor
or something like that? Open to suggestions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you take a second pass at your earliest convenience?
One of the tests is a little awkward in that ToPILImage converts FloatTensors to bytes.
The other thing was I'm unsure of the parameter name "swap_to_channelsfirst".
Let me know.
Removes the extra if conditionals and replaces from_numpy with as_tensor.
Renames the function as_tensor to pil_to_tensor and narrows the scope of the function. At the same time also creates a flag that defaults to True for swapping to the channels first format.
Renames the function AsTensor to PILToImage and modifies the description. Adds the swap_to_channelsfirst boolean variable to indicate if the user wishes to change the shape of the input.
Add the __init__ function to PILToTensor since it contains the swap_to_channelsfirst parameter now.
remove trailing white space
Reflects the name change to PILToTensor and the parameter to the function as well as the new narrowed scope that the function only accepts PIL images.
Add img.view(pic.size[1], pic.size[0], len(pic.getbands())) back to outside the if condition.
Torch tranpose operates differently than numpy transpose. Changed operation to permute.
Add mode information when converting to PIL Image from Float Tensor.
Codecov Report
@@ Coverage Diff @@
## master #2092 +/- ##
=========================================
- Coverage 0.48% 0.48% -0.01%
=========================================
Files 92 92
Lines 7409 7428 +19
Branches 1128 1131 +3
=========================================
Hits 36 36
- Misses 7360 7379 +19
Partials 13 13
Continue to review full report at Codecov.
|
return torch.as_tensor(nppic) | ||
|
||
# handle PIL Image | ||
img = torch.as_tensor(np.asarray(pic)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line will still produce the same bug mentioned in #2194. Converting to numpy with np.asarray(pic)
will keep the PIL image non-writeable. If instead we use np.array(pic)
, the bug #2194 would not appear. But I believe this is no real fix because np.array(pic)
copies the data which might be unintended behavior here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we will wait until PyTorch fixes this behavior in master, as making a copy would be fairly expensive.
If the warnings are too annoying, an alternative would be to only use asarray
if the array is writeable, and use array
if it is non-writeable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very sorry for the delay in reviewing this again.
I think the PR is in a very good shape, thanks a lot @xksteven !
I thought a bit more about the swap_to_channelsfirst
, and I think it's better to always swap to CHW. The reason being that all operations in PyTorch expects CHW sizes for images, even if the memory format (due to strides) could be in HWC. So I think that we should follow this, and once memory_format support is more widespread, we can set this flag to the function.
Once this flag is removed, this PR is ready to be merged.
torchvision/transforms/functional.py
Outdated
|
||
Args: | ||
pic (PIL Image): Image to be converted to tensor. | ||
swap_to_channelsfirst (bool): Boolean indicator to convert to CHW format. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
after a second thought, let's remove this flag and always perform the transpose to CHW format.
Makes the channel swapping the default behavior.
Remove the swap_channelsfirst argument and makes the swapping the default functionality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @xksteven !
This adds an as_tensor function as discussed in #2060 (comment).
The idea behind this function is to first convert the image into a torch.Tensor of the same dtype as the inputted image.
I do not know have an example TIFF image to test out if this issue #856 (comment) is addressed or not. Please instruct me on how to proceed. Thanks