From 300a90926e88f13abbaf3d8155cdba36aab86ab4 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 20 May 2023 16:05:57 -0400 Subject: [PATCH] Add non-TS'able _resize_image_and_masks variant with less tensor ops (#7592) Signed-off-by: Edward Z. Yang --- torchvision/models/detection/transform.py | 36 ++++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 589d5e45bdc..658c9e83455 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float: def _resize_image_and_masks( image: Tensor, - self_min_size: float, - self_max_size: float, + self_min_size: int, + self_max_size: int, target: Optional[Dict[str, Tensor]] = None, fixed_size: Optional[Tuple[int, int]] = None, ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: @@ -40,14 +40,24 @@ def _resize_image_and_masks( if fixed_size is not None: size = [fixed_size[1], fixed_size[0]] else: - min_size = torch.min(im_shape).to(dtype=torch.float32) - max_size = torch.max(im_shape).to(dtype=torch.float32) - scale = torch.min(self_min_size / min_size, self_max_size / max_size) + if torch.jit.is_scripting() or torchvision._is_tracing(): + min_size = torch.min(im_shape).to(dtype=torch.float32) + max_size = torch.max(im_shape).to(dtype=torch.float32) + self_min_size_f = float(self_min_size) + self_max_size_f = float(self_max_size) + scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size) + + if torchvision._is_tracing(): + scale_factor = _fake_cast_onnx(scale) + else: + scale_factor = scale.item() - if torchvision._is_tracing(): - scale_factor = _fake_cast_onnx(scale) else: - scale_factor = scale.item() + # Do it the normal way + min_size = min(im_shape) + max_size = max(im_shape) + scale_factor = min(self_min_size / min_size, self_max_size / max_size) + recompute_scale_factor = True image = torch.nn.functional.interpolate( @@ -159,8 +169,7 @@ def normalize(self, image: Tensor) -> Tensor: def torch_choice(self, k: List[int]) -> int: """ Implements `random.choice` via torch ops, so it can be compiled with - TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 - is fixed. + TorchScript and we use PyTorch's RNG (not native RNG) """ index = int(torch.empty(1).uniform_(0.0, float(len(k))).item()) return k[index] @@ -174,11 +183,10 @@ def resize( if self.training: if self._skip_resize: return image, target - size = float(self.torch_choice(self.min_size)) + size = self.torch_choice(self.min_size) else: - # FIXME assume for now that testing uses the largest scale - size = float(self.min_size[-1]) - image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size) + size = self.min_size[-1] + image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size) if target is None: return image, target