44
55import PIL .Image
66import torch
7- from torch .nn .functional import interpolate
7+ from torch .nn .functional import interpolate , pad as torch_pad
8+
89from torchvision .prototype import features
910from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
1011from torchvision .transforms .functional import (
1516 pil_to_tensor ,
1617 to_pil_image ,
1718)
18- from torchvision .transforms .functional_tensor import _parse_pad_padding
1919
2020from ._meta import convert_format_bounding_box , get_spatial_size_image_pil
2121
@@ -663,7 +663,28 @@ def rotate(
663663 return rotate_image_pil (inpt , angle , interpolation = interpolation , expand = expand , fill = fill , center = center )
664664
665665
666- pad_image_pil = _FP .pad
666+ def _parse_pad_padding (padding : Union [int , List [int ]]) -> List [int ]:
667+ if isinstance (padding , int ):
668+ pad_left = pad_right = pad_top = pad_bottom = padding
669+ elif isinstance (padding , (tuple , list )):
670+ if len (padding ) == 1 :
671+ pad_left = pad_right = pad_top = pad_bottom = padding [0 ]
672+ elif len (padding ) == 2 :
673+ pad_left = pad_right = padding [0 ]
674+ pad_top = pad_bottom = padding [1 ]
675+ elif len (padding ) == 4 :
676+ pad_left = padding [0 ]
677+ pad_top = padding [1 ]
678+ pad_right = padding [2 ]
679+ pad_bottom = padding [3 ]
680+ else :
681+ raise ValueError (
682+ f"Padding must be an int or a 1, 2, or 4 element tuple, not a { len (padding )} element tuple"
683+ )
684+ else :
685+ raise TypeError (f"`padding` should be an integer or tuple or list of integers, but got { padding } " )
686+
687+ return [pad_left , pad_right , pad_top , pad_bottom ]
667688
668689
669690def pad_image_tensor (
@@ -672,50 +693,86 @@ def pad_image_tensor(
672693 fill : features .FillTypeJIT = None ,
673694 padding_mode : str = "constant" ,
674695) -> torch .Tensor :
696+ # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses
697+ # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
698+ # internally.
699+ torch_padding = _parse_pad_padding (padding )
700+
701+ if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
702+ raise ValueError (
703+ f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
704+ f"but got `'{ padding_mode } '`."
705+ )
706+
675707 if fill is None :
676- # This is a JIT workaround
677- return _pad_with_scalar_fill (image , padding , fill = None , padding_mode = padding_mode )
678- elif isinstance (fill , (int , float )) or len (fill ) == 1 :
679- fill_number = fill [0 ] if isinstance (fill , list ) else fill
680- return _pad_with_scalar_fill (image , padding , fill = fill_number , padding_mode = padding_mode )
708+ fill = 0
709+
710+ if isinstance (fill , (int , float )):
711+ return _pad_with_scalar_fill (image , torch_padding , fill = fill , padding_mode = padding_mode )
712+ elif len (fill ) == 1 :
713+ return _pad_with_scalar_fill (image , torch_padding , fill = fill [0 ], padding_mode = padding_mode )
681714 else :
682- return _pad_with_vector_fill (image , padding , fill = fill , padding_mode = padding_mode )
715+ return _pad_with_vector_fill (image , torch_padding , fill = fill , padding_mode = padding_mode )
683716
684717
685718def _pad_with_scalar_fill (
686719 image : torch .Tensor ,
687- padding : Union [ int , List [int ] ],
688- fill : Union [int , float , None ],
689- padding_mode : str = "constant" ,
720+ torch_padding : List [int ],
721+ fill : Union [int , float ],
722+ padding_mode : str ,
690723) -> torch .Tensor :
691724 shape = image .shape
692725 num_channels , height , width = shape [- 3 :]
693726
694727 if image .numel () > 0 :
695- image = _FT .pad (
696- img = image .reshape (- 1 , num_channels , height , width ), padding = padding , fill = fill , padding_mode = padding_mode
697- )
728+ image = image .reshape (- 1 , num_channels , height , width )
729+
730+ if padding_mode == "edge" :
731+ # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
732+ # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
733+ # name.
734+ padding_mode = "replicate"
735+
736+ if padding_mode == "constant" :
737+ image = torch_pad (image , torch_padding , mode = padding_mode , value = float (fill ))
738+ elif padding_mode in ("reflect" , "replicate" ):
739+ # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
740+ # TODO: See https://github.com/pytorch/pytorch/issues/40763
741+ dtype = image .dtype
742+ if not image .is_floating_point ():
743+ needs_cast = True
744+ image = image .to (torch .float32 )
745+ else :
746+ needs_cast = False
747+
748+ image = torch_pad (image , torch_padding , mode = padding_mode )
749+
750+ if needs_cast :
751+ image = image .to (dtype )
752+ else : # padding_mode == "symmetric"
753+ image = _FT ._pad_symmetric (image , torch_padding )
754+
698755 new_height , new_width = image .shape [- 2 :]
699756 else :
700- left , right , top , bottom = _FT . _parse_pad_padding ( padding )
757+ left , right , top , bottom = torch_padding
701758 new_height = height + top + bottom
702759 new_width = width + left + right
703760
704761 return image .reshape (shape [:- 3 ] + (num_channels , new_height , new_width ))
705762
706763
707- # TODO: This should be removed once pytorch pad supports non-scalar padding values
764+ # TODO: This should be removed once torch_pad supports non-scalar padding values
708765def _pad_with_vector_fill (
709766 image : torch .Tensor ,
710- padding : Union [ int , List [int ] ],
767+ torch_padding : List [int ],
711768 fill : List [float ],
712- padding_mode : str = "constant" ,
769+ padding_mode : str ,
713770) -> torch .Tensor :
714771 if padding_mode != "constant" :
715772 raise ValueError (f"Padding mode '{ padding_mode } ' is not supported if fill is not scalar" )
716773
717- output = _pad_with_scalar_fill (image , padding , fill = 0 , padding_mode = "constant" )
718- left , right , top , bottom = _parse_pad_padding ( padding )
774+ output = _pad_with_scalar_fill (image , torch_padding , fill = 0 , padding_mode = "constant" )
775+ left , right , top , bottom = torch_padding
719776 fill = torch .tensor (fill , dtype = image .dtype , device = image .device ).reshape (- 1 , 1 , 1 )
720777
721778 if top > 0 :
@@ -729,6 +786,9 @@ def _pad_with_vector_fill(
729786 return output
730787
731788
789+ pad_image_pil = _FP .pad
790+
791+
732792def pad_mask (
733793 mask : torch .Tensor ,
734794 padding : Union [int , List [int ]],
0 commit comments