Skip to content

Commit 300a909

Browse files
authored
Add non-TS'able _resize_image_and_masks variant with less tensor ops (#7592)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
1 parent d2f7486 commit 300a909

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

torchvision/models/detection/transform.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float:
2424

2525
def _resize_image_and_masks(
2626
image: Tensor,
27-
self_min_size: float,
28-
self_max_size: float,
27+
self_min_size: int,
28+
self_max_size: int,
2929
target: Optional[Dict[str, Tensor]] = None,
3030
fixed_size: Optional[Tuple[int, int]] = None,
3131
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
@@ -40,14 +40,24 @@ def _resize_image_and_masks(
4040
if fixed_size is not None:
4141
size = [fixed_size[1], fixed_size[0]]
4242
else:
43-
min_size = torch.min(im_shape).to(dtype=torch.float32)
44-
max_size = torch.max(im_shape).to(dtype=torch.float32)
45-
scale = torch.min(self_min_size / min_size, self_max_size / max_size)
43+
if torch.jit.is_scripting() or torchvision._is_tracing():
44+
min_size = torch.min(im_shape).to(dtype=torch.float32)
45+
max_size = torch.max(im_shape).to(dtype=torch.float32)
46+
self_min_size_f = float(self_min_size)
47+
self_max_size_f = float(self_max_size)
48+
scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
49+
50+
if torchvision._is_tracing():
51+
scale_factor = _fake_cast_onnx(scale)
52+
else:
53+
scale_factor = scale.item()
4654

47-
if torchvision._is_tracing():
48-
scale_factor = _fake_cast_onnx(scale)
4955
else:
50-
scale_factor = scale.item()
56+
# Do it the normal way
57+
min_size = min(im_shape)
58+
max_size = max(im_shape)
59+
scale_factor = min(self_min_size / min_size, self_max_size / max_size)
60+
5161
recompute_scale_factor = True
5262

5363
image = torch.nn.functional.interpolate(
@@ -159,8 +169,7 @@ def normalize(self, image: Tensor) -> Tensor:
159169
def torch_choice(self, k: List[int]) -> int:
160170
"""
161171
Implements `random.choice` via torch ops, so it can be compiled with
162-
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
163-
is fixed.
172+
TorchScript and we use PyTorch's RNG (not native RNG)
164173
"""
165174
index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
166175
return k[index]
@@ -174,11 +183,10 @@ def resize(
174183
if self.training:
175184
if self._skip_resize:
176185
return image, target
177-
size = float(self.torch_choice(self.min_size))
186+
size = self.torch_choice(self.min_size)
178187
else:
179-
# FIXME assume for now that testing uses the largest scale
180-
size = float(self.min_size[-1])
181-
image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)
188+
size = self.min_size[-1]
189+
image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
182190

183191
if target is None:
184192
return image, target

0 commit comments

Comments
 (0)