@@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float:
2424
2525def _resize_image_and_masks (
2626 image : Tensor ,
27- self_min_size : float ,
28- self_max_size : float ,
27+ self_min_size : int ,
28+ self_max_size : int ,
2929 target : Optional [Dict [str , Tensor ]] = None ,
3030 fixed_size : Optional [Tuple [int , int ]] = None ,
3131) -> Tuple [Tensor , Optional [Dict [str , Tensor ]]]:
@@ -40,14 +40,24 @@ def _resize_image_and_masks(
4040 if fixed_size is not None :
4141 size = [fixed_size [1 ], fixed_size [0 ]]
4242 else :
43- min_size = torch .min (im_shape ).to (dtype = torch .float32 )
44- max_size = torch .max (im_shape ).to (dtype = torch .float32 )
45- scale = torch .min (self_min_size / min_size , self_max_size / max_size )
43+ if torch .jit .is_scripting () or torchvision ._is_tracing ():
44+ min_size = torch .min (im_shape ).to (dtype = torch .float32 )
45+ max_size = torch .max (im_shape ).to (dtype = torch .float32 )
46+ self_min_size_f = float (self_min_size )
47+ self_max_size_f = float (self_max_size )
48+ scale = torch .min (self_min_size_f / min_size , self_max_size_f / max_size )
49+
50+ if torchvision ._is_tracing ():
51+ scale_factor = _fake_cast_onnx (scale )
52+ else :
53+ scale_factor = scale .item ()
4654
47- if torchvision ._is_tracing ():
48- scale_factor = _fake_cast_onnx (scale )
4955 else :
50- scale_factor = scale .item ()
56+ # Do it the normal way
57+ min_size = min (im_shape )
58+ max_size = max (im_shape )
59+ scale_factor = min (self_min_size / min_size , self_max_size / max_size )
60+
5161 recompute_scale_factor = True
5262
5363 image = torch .nn .functional .interpolate (
@@ -159,8 +169,7 @@ def normalize(self, image: Tensor) -> Tensor:
159169 def torch_choice (self , k : List [int ]) -> int :
160170 """
161171 Implements `random.choice` via torch ops, so it can be compiled with
162- TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
163- is fixed.
172+ TorchScript and we use PyTorch's RNG (not native RNG)
164173 """
165174 index = int (torch .empty (1 ).uniform_ (0.0 , float (len (k ))).item ())
166175 return k [index ]
@@ -174,11 +183,10 @@ def resize(
174183 if self .training :
175184 if self ._skip_resize :
176185 return image , target
177- size = float ( self .torch_choice (self .min_size ) )
186+ size = self .torch_choice (self .min_size )
178187 else :
179- # FIXME assume for now that testing uses the largest scale
180- size = float (self .min_size [- 1 ])
181- image , target = _resize_image_and_masks (image , size , float (self .max_size ), target , self .fixed_size )
188+ size = self .min_size [- 1 ]
189+ image , target = _resize_image_and_masks (image , size , self .max_size , target , self .fixed_size )
182190
183191 if target is None :
184192 return image , target
0 commit comments