Skip to content

Commit

Permalink
Add non-TS'able _resize_image_and_masks variant with less tensor ops (#…
Browse files Browse the repository at this point in the history
…7592)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
  • Loading branch information
ezyang authored May 20, 2023
1 parent d2f7486 commit 300a909
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float:

def _resize_image_and_masks(
image: Tensor,
self_min_size: float,
self_max_size: float,
self_min_size: int,
self_max_size: int,
target: Optional[Dict[str, Tensor]] = None,
fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
Expand All @@ -40,14 +40,24 @@ def _resize_image_and_masks(
if fixed_size is not None:
size = [fixed_size[1], fixed_size[0]]
else:
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
scale = torch.min(self_min_size / min_size, self_max_size / max_size)
if torch.jit.is_scripting() or torchvision._is_tracing():
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
self_min_size_f = float(self_min_size)
self_max_size_f = float(self_max_size)
scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

if torchvision._is_tracing():
scale_factor = _fake_cast_onnx(scale)
else:
scale_factor = scale.item()

if torchvision._is_tracing():
scale_factor = _fake_cast_onnx(scale)
else:
scale_factor = scale.item()
# Do it the normal way
min_size = min(im_shape)
max_size = max(im_shape)
scale_factor = min(self_min_size / min_size, self_max_size / max_size)

recompute_scale_factor = True

image = torch.nn.functional.interpolate(
Expand Down Expand Up @@ -159,8 +169,7 @@ def normalize(self, image: Tensor) -> Tensor:
def torch_choice(self, k: List[int]) -> int:
"""
Implements `random.choice` via torch ops, so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed.
TorchScript and we use PyTorch's RNG (not native RNG)
"""
index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
return k[index]
Expand All @@ -174,11 +183,10 @@ def resize(
if self.training:
if self._skip_resize:
return image, target
size = float(self.torch_choice(self.min_size))
size = self.torch_choice(self.min_size)
else:
# FIXME assume for now that testing uses the largest scale
size = float(self.min_size[-1])
image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)
size = self.min_size[-1]
image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

if target is None:
return image, target
Expand Down

0 comments on commit 300a909

Please sign in to comment.