Skip to content

Commit be42ae3

Browse files
Yosua Michael Maranathafacebook-github-bot
authored andcommitted
[fbsync] Replace getbands() with get_image_num_channels() (#6941)
Reviewed By: NicolasHug Differential Revision: D41265203 fbshipit-source-id: 082a28ef8f8809e313c2ecf014ea156c943adb92
1 parent e1efbb4 commit be42ae3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def to_tensor(pic) -> Tensor:
167167

168168
if pic.mode == "1":
169169
img = 255 * img
170-
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
170+
img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
171171
# put it from HWC to CHW format
172172
img = img.permute((2, 0, 1)).contiguous()
173173
if isinstance(img, torch.ByteTensor):
@@ -205,7 +205,7 @@ def pil_to_tensor(pic: Any) -> Tensor:
205205

206206
# handle PIL Image
207207
img = torch.as_tensor(np.array(pic, copy=True))
208-
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
208+
img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
209209
# put it from HWC to CHW format
210210
img = img.permute((2, 0, 1))
211211
return img

0 commit comments

Comments
 (0)