Skip to content

Commit

Permalink
Avoid _prims_common.check in favor of torch._check (#7625)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
3 people authored May 30, 2023
1 parent 9d0a93e commit f36c5de
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401

from torch._prims_common import check


@functools.lru_cache(None)
def get_meta_lib():
Expand All @@ -25,8 +23,8 @@ def wrapper(fn):

@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
check(
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
Expand All @@ -42,7 +40,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
check(
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
Expand Down

0 comments on commit f36c5de

Please sign in to comment.