Skip to content

Add compatibility checks for C++ extensions #2467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions torchvision/extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
_HAS_OPS = False
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can probably be removed now



def _has_ops():
return False


def _register_extensions():
import os
import importlib
Expand All @@ -23,10 +27,26 @@ def _register_extensions():
try:
_register_extensions()
_HAS_OPS = True

def _has_ops(): # noqa: F811
return True
except (ImportError, OSError):
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing this!
Perhaps it's possible to define _has_ops function here that returns (bool has_ops, str err_msg). Then the exact error message from the exception can be used in _assert_has_ops

Copy link
Member Author

@fmassa fmassa Oct 6, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review, I forgot to answer this.

I think this should be possible to be done, let me try this.

Given that torchscript doesn't support globals, I think the only way we could have this is to have pre-canned error messages that are used to construct the _has_ops function.

Something like

try:
    _register_extensions()
    ...
except (ImportError, OSError) as err:
    err_str = repr(err)
    if err_type_1 in err_str:
        def _has_ops():
            ...
    elif err_type_2 in err_str:
        ...

which makes it a bit less readable, and kinds of defeats the purpose of having the error message being returned because we would only be able to nicely print a few types of errors (the ones we hard-coded).

But maybe I'm missing something here?



def _assert_has_ops():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should do something similar for the video and image ops, but I'll do that in a follow-up PR once we define more precisely if this implementation is enough

if not _has_ops():
raise RuntimeError(
"Couldn't load custom C++ ops. This can happen if your PyTorch and "
"torchvision versions are incompatible, or if you had errors while compiling "
"torchvision from source. For further information on the compatible versions, check "
"https://github.com/pytorch/vision#installation for the compatibility matrix. "
"Please check your PyTorch version with torch.__version__ and your torchvision "
"version with torchvision.__version__ and verify if they are compatible, and if not "
"please reinstall torchvision so that it matches your PyTorch install."
)


def _check_cuda_version():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch import Tensor
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision
from torchvision.extension import _assert_has_ops


def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
Expand Down Expand Up @@ -37,6 +38,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
_assert_has_ops()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple
from torchvision.extension import _assert_has_ops


def deform_conv2d(
Expand Down Expand Up @@ -51,6 +52,7 @@ def deform_conv2d(
>>> torch.Size([4, 5, 8, 8])
"""

_assert_has_ops()
out_channels = weight.shape[0]
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/ps_roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, Tuple

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


Expand Down Expand Up @@ -38,6 +39,7 @@ def ps_roi_align(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/ps_roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, Tuple

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


Expand Down Expand Up @@ -32,6 +33,7 @@ def ps_roi_pool(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


Expand Down Expand Up @@ -41,6 +42,7 @@ def roi_align(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


Expand Down Expand Up @@ -31,6 +32,7 @@ def roi_pool(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
Expand Down