diff --git a/test/test_ops.py b/test/test_ops.py
index 52a66f380a6..99b259f73f5 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -14,6 +14,7 @@
 from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
 from torchvision import models, ops
@@ -529,6 +530,10 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
     def test_backward(self, seed, device, contiguous, deterministic):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
+        if deterministic and device == "mps":
+            pytest.skip("no deterministic implementation for mps")
+        if deterministic and not is_compile_supported(device):
+            pytest.skip("deterministic implementation only if torch.compile supported")
         super().test_backward(seed, device, contiguous, deterministic)
 
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py
index 0d505c140ee..ac1ec8b429a 100644
--- a/torchvision/ops/roi_align.py
+++ b/torchvision/ops/roi_align.py
@@ -1,9 +1,11 @@
+import functools
 from typing import List, Union
 
 import torch
 import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops, _has_ops
@@ -12,6 +14,24 @@
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
 
 
+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
 # NB: all inputs are tensors
 def _bilinear_interpolate(
     input,  # [N, C, H, W]
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
         return tensor
 
 
-# This is a slow but pure Python and differentiable implementation of
-# roi_align. It potentially is a good basis for Inductor compilation
-# (but I have not benchmarked it) but today it is solely used for the
-# fact that its backwards can be implemented deterministically,
-# which is needed for the PT2 benchmark suite.
-#
+# This is a pure Python and differentiable implementation of roi_align. When
+# run in eager mode, it uses a lot of memory, but when compiled it has
+# acceptable memory usage. The main point of this implementation is that
+# its backwards is deterministic.
 # It is transcribed directly off of the roi_align CUDA kernel, see
 # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
-@torch._dynamo.allow_in_graph
+@lazy_compile(dynamic=True)
 def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
     orig_dtype = input.dtype
 
@@ -232,7 +250,9 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+        if (
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+        ) and is_compile_supported(input.device.type):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
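Reviewer note, not part of the diff: the lazy_compile decorator above pays the torch.compile cost (and the dynamo import) at most once, by rebinding the module-global name of the decorated function to the compiled version on first call. A minimal standalone sketch of the same pattern, using a made-up toy function double_sum in place of _roi_align:

import functools

import torch


def lazy_compile(**compile_kwargs):
    """Defer torch.compile until the wrapped function is first called."""

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            # First call: compile fn, then rebind the module-global name so
            # later lookups resolve directly to the compiled function and
            # this hook never runs again.
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn


@lazy_compile(dynamic=True)
def double_sum(x):  # toy stand-in for _roi_align
    return (x * 2).sum()


x = torch.arange(4.0)
print(double_sum(x))  # first call triggers torch.compile
print(double_sum(x))  # resolves straight to the already-compiled function

One caveat of the rebinding trick: globals() inside compile_hook is the namespace of the module where lazy_compile is defined, so it only short-circuits future calls when the decorated function lives in that same module, as _roi_align does in roi_align.py.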