Skip to content

Commit 0d7807d

Browse files
authored
[prototype] Cleaning up the size dimension methods (#6828)
* Cleaning up the size dimension methods. * Change error messages.
1 parent 7278abe commit 0d7807d

File tree

1 file changed

+31
-6
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+31
-6
lines changed

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
88

99

10-
get_dimensions_image_tensor = _FT.get_dimensions
10+
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
11+
chw = list(image.shape[-3:])
12+
ndims = len(chw)
13+
if ndims == 3:
14+
return chw
15+
elif ndims == 2:
16+
chw.insert(0, 1)
17+
return chw
18+
else:
19+
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
20+
21+
1122
get_dimensions_image_pil = _FP.get_dimensions
1223

1324

@@ -24,7 +35,17 @@ def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -
2435
return get_dimensions_image_pil(image)
2536

2637

27-
get_num_channels_image_tensor = _FT.get_image_num_channels
38+
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
39+
chw = image.shape[-3:]
40+
ndims = len(chw)
41+
if ndims == 3:
42+
return chw[0]
43+
elif ndims == 2:
44+
return 1
45+
else:
46+
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
47+
48+
2849
get_num_channels_image_pil = _FP.get_image_num_channels
2950

3051

@@ -36,11 +57,11 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT])
3657
if isinstance(image, torch.Tensor) and (
3758
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
3859
):
39-
return _FT.get_image_num_channels(image)
60+
return get_num_channels_image_tensor(image)
4061
elif isinstance(image, (features.Image, features.Video)):
4162
return image.num_channels
4263
else:
43-
return _FP.get_image_num_channels(image)
64+
return get_num_channels_image_pil(image)
4465

4566

4667
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -49,8 +70,12 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT])
4970

5071

5172
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
52-
width, height = _FT.get_image_size(image)
53-
return [height, width]
73+
hw = list(image.shape[-2:])
74+
ndims = len(hw)
75+
if ndims == 2:
76+
return hw
77+
else:
78+
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
5479

5580

5681
@torch.jit.unused

0 commit comments

Comments
 (0)