Commit a5f531a
Force use of torch.compile on deterministic roi_align implementation (#8436)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
ezyang and NicolasHug authored May 29, 2024
1 parent 775dd2d commit a5f531a
Showing 2 changed files with 33 additions and 8 deletions.
5 changes: 5 additions & 0 deletions test/test_ops.py
@@ -14,6 +14,7 @@
 from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
 from torchvision import models, ops
@@ -529,6 +530,10 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
     def test_backward(self, seed, device, contiguous, deterministic):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
+        if deterministic and device == "mps":
+            pytest.skip("no deterministic implementation for mps")
+        if deterministic and not is_compile_supported(device):
+            pytest.skip("deterministic implementation only if torch.compile supported")
         super().test_backward(seed, device, contiguous, deterministic)

     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
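For reference, torch._dynamo.utils.is_compile_supported takes a device type string and reports whether torch.compile can target that device; this is what gates the new deterministic test above. A quick check (a sketch; this is a private API and may change):

    import torch
    from torch._dynamo.utils import is_compile_supported

    # The test passes its `device` parameter (e.g. "cpu" or "cuda") straight through.
    for dev in ("cpu", "cuda"):
        print(dev, is_compile_supported(dev))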
36 changes: 28 additions & 8 deletions torchvision/ops/roi_align.py
@@ -1,9 +1,11 @@
+import functools
 from typing import List, Union

 import torch
 import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops, _has_ops
@@ -12,6 +14,24 @@
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
 # NB: all inputs are tensors
 def _bilinear_interpolate(
     input,  # [N, C, H, W]
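A minimal sketch (not part of this commit) of the decorator in action. It relies on the decorated function living in the same module as lazy_compile, because globals() inside compile_hook is the namespace of the module that defines the hook; that assumption holds for _roi_align below. The function `double` is purely illustrative:

    import torch

    @lazy_compile(dynamic=True)  # the decorator defined in the hunk above
    def double(x):
        return x * 2

    x = torch.arange(4)
    print(double(x))  # first call: runs torch.compile and rebinds the global name
    print(double(x))  # later calls go straight to the compiled function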
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
         return tensor


-# This is a slow but pure Python and differentiable implementation of
-# roi_align. It potentially is a good basis for Inductor compilation
-# (but I have not benchmarked it) but today it is solely used for the
-# fact that its backwards can be implemented deterministically,
-# which is needed for the PT2 benchmark suite.
-#
+# This is a pure Python and differentiable implementation of roi_align. When
+# run in eager mode, it uses a lot of memory, but when compiled it has
+# acceptable memory usage. The main point of this implementation is that
+# its backwards is deterministic.
 # It is transcribed directly off of the roi_align CUDA kernel, see
 # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
-@torch._dynamo.allow_in_graph
+@lazy_compile(dynamic=True)
 def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
     orig_dtype = input.dtype

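The practical effect, sketched below (assumes a CUDA build; rows of `rois` use torchvision's `(batch_idx, x1, y1, x2, y2)` layout): with deterministic algorithms enabled, CUDA inputs now dispatch to the compiled `_roi_align`, whose backward avoids the atomic adds that make the CUDA kernel's backward nondeterministic.

    import torch
    from torchvision.ops import roi_align

    torch.use_deterministic_algorithms(True)

    device = "cuda" if torch.cuda.is_available() else "cpu"  # cpu is already deterministic
    x = torch.rand(1, 3, 32, 32, device=device, requires_grad=True)
    rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]], device=device)
    out = roi_align(x, rois, output_size=(4, 4), spatial_scale=1.0, sampling_ratio=2)
    out.sum().backward()  # gradients are reproducible on the Python path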
@@ -232,7 +250,9 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+        if (
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+        ) and is_compile_supported(input.device.type):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
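Restated as a standalone predicate (a sketch using the same private helpers as the diff), the Python implementation is only selected when torch.compile can actually run on the input's device; otherwise the C++/CUDA operator is used, and `_assert_has_ops` raises if it is unavailable:

    import torch
    from torch._dynamo.utils import is_compile_supported
    from torchvision.extension import _has_ops

    input = torch.rand(1, 3, 8, 8)  # the device of this tensor drives the decision
    use_python_impl = (
        not _has_ops()
        or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
    ) and is_compile_supported(input.device.type)
    print(use_python_impl)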
