diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index 53d1b05fb3b..659d4e958cc 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -5,7 +5,7 @@ import PIL.Image import torch -from torch._C import DisableTorchFunction +from torch._C import DisableTorchFunctionSubclass from torch.types import _device, _dtype, _size from torchvision.transforms import InterpolationMode @@ -87,7 +87,7 @@ def __torch_function__( if not all(issubclass(cls, t) for t in types): return NotImplemented - with DisableTorchFunction(): + with DisableTorchFunctionSubclass(): output = func(*args, **kwargs or dict()) wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) @@ -129,22 +129,22 @@ def _F(self) -> ModuleType: # this way we return the result without passing into __torch_function__ @property def shape(self) -> _size: # type: ignore[override] - with DisableTorchFunction(): + with DisableTorchFunctionSubclass(): return super().shape @property def ndim(self) -> int: # type: ignore[override] - with DisableTorchFunction(): + with DisableTorchFunctionSubclass(): return super().ndim @property def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override] - with DisableTorchFunction(): + with DisableTorchFunctionSubclass(): return super().device @property def dtype(self) -> _dtype: # type: ignore[override] - with DisableTorchFunction(): + with DisableTorchFunctionSubclass(): return super().dtype def horizontal_flip(self) -> Datapoint: