1717 _check_padding_arg ,
1818 _check_padding_mode_arg ,
1919 _check_sequence_input ,
20+ _get_fill ,
2021 _setup_angle ,
2122 _setup_fill_arg ,
2223 _setup_float_or_seq ,
@@ -487,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
487488 def __init__ (
488489 self ,
489490 padding : Union [int , Sequence [int ]],
490- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
491+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
491492 padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
492493 ) -> None :
493494 super ().__init__ ()
@@ -504,7 +505,7 @@ def __init__(
504505 self .padding_mode = padding_mode
505506
506507 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
507- fill = self ._fill [ type (inpt )]
508+ fill = _get_fill ( self ._fill , type (inpt ))
508509 return F .pad (inpt , padding = self .padding , fill = fill , padding_mode = self .padding_mode ) # type: ignore[arg-type]
509510
510511
@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
542543
543544 def __init__ (
544545 self ,
545- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
546+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
546547 side_range : Sequence [float ] = (1.0 , 4.0 ),
547548 p : float = 0.5 ,
548549 ) -> None :
@@ -574,7 +575,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
574575 return dict (padding = padding )
575576
576577 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
577- fill = self ._fill [ type (inpt )]
578+ fill = _get_fill ( self ._fill , type (inpt ))
578579 return F .pad (inpt , ** params , fill = fill )
579580
580581
@@ -620,7 +621,7 @@ def __init__(
620621 interpolation : Union [InterpolationMode , int ] = InterpolationMode .NEAREST ,
621622 expand : bool = False ,
622623 center : Optional [List [float ]] = None ,
623- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
624+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
624625 ) -> None :
625626 super ().__init__ ()
626627 self .degrees = _setup_angle (degrees , name = "degrees" , req_sizes = (2 ,))
@@ -640,7 +641,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
640641 return dict (angle = angle )
641642
642643 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
643- fill = self ._fill [ type (inpt )]
644+ fill = _get_fill ( self ._fill , type (inpt ))
644645 return F .rotate (
645646 inpt ,
646647 ** params ,
@@ -702,7 +703,7 @@ def __init__(
702703 scale : Optional [Sequence [float ]] = None ,
703704 shear : Optional [Union [int , float , Sequence [float ]]] = None ,
704705 interpolation : Union [InterpolationMode , int ] = InterpolationMode .NEAREST ,
705- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
706+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
706707 center : Optional [List [float ]] = None ,
707708 ) -> None :
708709 super ().__init__ ()
@@ -762,7 +763,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
762763 return dict (angle = angle , translate = translate , scale = scale , shear = shear )
763764
764765 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
765- fill = self ._fill [ type (inpt )]
766+ fill = _get_fill ( self ._fill , type (inpt ))
766767 return F .affine (
767768 inpt ,
768769 ** params ,
@@ -840,7 +841,7 @@ def __init__(
840841 size : Union [int , Sequence [int ]],
841842 padding : Optional [Union [int , Sequence [int ]]] = None ,
842843 pad_if_needed : bool = False ,
843- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
844+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
844845 padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
845846 ) -> None :
846847 super ().__init__ ()
@@ -918,7 +919,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
918919
919920 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
920921 if params ["needs_pad" ]:
921- fill = self ._fill [ type (inpt )]
922+ fill = _get_fill ( self ._fill , type (inpt ))
922923 inpt = F .pad (inpt , padding = params ["padding" ], fill = fill , padding_mode = self .padding_mode )
923924
924925 if params ["needs_crop" ]:
@@ -959,7 +960,7 @@ def __init__(
959960 distortion_scale : float = 0.5 ,
960961 p : float = 0.5 ,
961962 interpolation : Union [InterpolationMode , int ] = InterpolationMode .BILINEAR ,
962- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
963+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
963964 ) -> None :
964965 super ().__init__ (p = p )
965966
@@ -1002,7 +1003,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
10021003 return dict (coefficients = perspective_coeffs )
10031004
10041005 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
1005- fill = self ._fill [ type (inpt )]
1006+ fill = _get_fill ( self ._fill , type (inpt ))
10061007 return F .perspective (
10071008 inpt ,
10081009 None ,
@@ -1061,7 +1062,7 @@ def __init__(
10611062 alpha : Union [float , Sequence [float ]] = 50.0 ,
10621063 sigma : Union [float , Sequence [float ]] = 5.0 ,
10631064 interpolation : Union [InterpolationMode , int ] = InterpolationMode .BILINEAR ,
1064- fill : Union [datapoints ._FillType , Dict [Type , datapoints ._FillType ]] = 0 ,
1065+ fill : Union [datapoints ._FillType , Dict [Union [ Type , str ] , datapoints ._FillType ]] = 0 ,
10651066 ) -> None :
10661067 super ().__init__ ()
10671068 self .alpha = _setup_float_or_seq (alpha , "alpha" , 2 )
@@ -1095,7 +1096,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
10951096 return dict (displacement = displacement )
10961097
10971098 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
1098- fill = self ._fill [ type (inpt )]
1099+ fill = _get_fill ( self ._fill , type (inpt ))
10991100 return F .elastic (
11001101 inpt ,
11011102 ** params ,
0 commit comments