Skip to content

Commit 98f1a40

Browse files
authored
Replace getbands() with get_image_num_channels() (#6941)
1 parent ffd5a56 commit 98f1a40

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def to_tensor(pic) -> Tensor:
167167

168168
if pic.mode == "1":
169169
img = 255 * img
170-
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
170+
img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
171171
# put it from HWC to CHW format
172172
img = img.permute((2, 0, 1)).contiguous()
173173
if isinstance(img, torch.ByteTensor):
@@ -205,7 +205,7 @@ def pil_to_tensor(pic: Any) -> Tensor:
205205

206206
# handle PIL Image
207207
img = torch.as_tensor(np.array(pic, copy=True))
208-
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
208+
img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
209209
# put it from HWC to CHW format
210210
img = img.permute((2, 0, 1))
211211
return img

0 commit comments

Comments
 (0)