@@ -1251,6 +1251,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
12511251 return self ._center_crop (x )
12521252
12531253
1254+ # Define TransformationRobustness defaults externally for easier Sphinx docs formatting
1255+ _TR_TRANSLATE : List [int ] = [4 ] * 10
1256+ _TR_SCALE : List [float ] = [0.995 ** n for n in range (- 5 , 80 )] + [
1257+ 0.998 ** n for n in 2 * list (range (20 , 40 ))
1258+ ]
1259+ _TR_DEGREES : List [int ] = (
1260+ list (range (- 20 , 20 )) + list (range (- 10 , 10 )) + list (range (- 5 , 5 )) + 5 * [0 ]
1261+ )
1262+
1263+
12541264class TransformationRobustness (nn .Module ):
12551265 """
12561266 This transform combines the standard transforms (:class:`.RandomSpatialJitter`,
@@ -1269,15 +1279,9 @@ class TransformationRobustness(nn.Module):
12691279 def __init__ (
12701280 self ,
12711281 padding_transform : Optional [nn .Module ] = nn .ConstantPad2d (2 , value = 0.5 ),
1272- translate : Optional [Union [int , List [int ]]] = [4 ] * 10 ,
1273- scale : Optional [NumSeqOrTensorOrProbDistType ] = [
1274- 0.995 ** n for n in range (- 5 , 80 )
1275- ]
1276- + [0.998 ** n for n in 2 * list (range (20 , 40 ))],
1277- degrees : Optional [NumSeqOrTensorOrProbDistType ] = list (range (- 20 , 20 ))
1278- + list (range (- 10 , 10 ))
1279- + list (range (- 5 , 5 ))
1280- + 5 * [0 ],
1282+ translate : Optional [Union [int , List [int ]]] = _TR_TRANSLATE ,
1283+ scale : Optional [NumSeqOrTensorOrProbDistType ] = _TR_SCALE ,
1284+ degrees : Optional [NumSeqOrTensorOrProbDistType ] = _TR_DEGREES ,
12811285 final_translate : Optional [int ] = 2 ,
12821286 crop_or_pad_output : bool = False ,
12831287 ) -> None :
0 commit comments