77from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
88
99
10- get_dimensions_image_tensor = _FT .get_dimensions
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
    """Return the ``[channels, height, width]`` of a tensor image.

    Only the trailing (up to three) dimensions of ``image`` are inspected, so
    any leading batch-like dimensions are ignored.

    Args:
        image: Tensor with at least two dimensions.

    Returns:
        A fresh three-element list. When ``image`` has exactly two dimensions
        a channel count of ``1`` is prepended to its shape.

    Raises:
        TypeError: If ``image`` has fewer than two dimensions.
    """
    dims = list(image.shape[-3:])
    ndims = len(dims)
    # The slice holds at most three entries, so anything below two means the
    # tensor itself is 0-d or 1-d.
    if ndims < 2:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims} ")
    if ndims == 2:
        # No explicit channel dimension: report a single channel.
        return [1] + dims
    return dims
20+
21+
1122get_dimensions_image_pil = _FP .get_dimensions
1223
1324
@@ -24,7 +35,17 @@ def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -
2435 return get_dimensions_image_pil (image )
2536
2637
27- get_num_channels_image_tensor = _FT .get_image_num_channels
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
    """Return the number of channels of a tensor image.

    Args:
        image: Tensor with at least two dimensions.

    Returns:
        The size of the third-from-last dimension when ``image`` has three or
        more dimensions, or ``1`` when it has exactly two.

    Raises:
        TypeError: If ``image`` has fewer than two dimensions.
    """
    # Length of the trailing-three slice: 3 for tensors with >= 3 dims,
    # otherwise the tensor's own number of dimensions.
    ndims = len(image.shape[-3:])
    if ndims == 2:
        return 1
    if ndims == 3:
        return image.shape[-3]
    raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims} ")
47+
48+
2849get_num_channels_image_pil = _FP .get_image_num_channels
2950
3051
@@ -36,11 +57,11 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT])
3657 if isinstance (image , torch .Tensor ) and (
3758 torch .jit .is_scripting () or not isinstance (image , (features .Image , features .Video ))
3859 ):
39- return _FT . get_image_num_channels (image )
60+ return get_num_channels_image_tensor (image )
4061 elif isinstance (image , (features .Image , features .Video )):
4162 return image .num_channels
4263 else :
43- return _FP . get_image_num_channels (image )
64+ return get_num_channels_image_pil (image )
4465
4566
4667# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -49,8 +70,12 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT])
4970
5071
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
    """Return the last two dimensions of a tensor image as a list.

    Args:
        image: Tensor with at least two dimensions.

    Returns:
        A fresh two-element list holding the sizes of the last two dimensions
        of ``image``, in the order they appear in ``image.shape``.

    Raises:
        TypeError: If ``image`` has fewer than two dimensions.
    """
    size = list(image.shape[-2:])
    ndims = len(size)
    # A slice shorter than two entries means the tensor is 0-d or 1-d.
    if ndims != 2:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims} ")
    return size
5479
5580
5681@torch .jit .unused
0 commit comments