diff --git a/sahi/utils/torch.py b/sahi/utils/torch.py index a5d289678..74a60270f 100644 --- a/sahi/utils/torch.py +++ b/sahi/utils/torch.py @@ -15,37 +15,37 @@ def empty_cuda_cache(): raise RuntimeError("CUDA not available.") -if is_available("torch"): - - @check_requirements(["torch"]) - def to_float_tensor(img): - """ - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W). - Args: - img: np.ndarray - Returns: - torch.tensor - """ - import torch +@check_requirements(["torch"]) +def to_float_tensor(img): + """ + Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W). + Args: + img: np.ndarray + Returns: + torch.tensor + """ + import torch - img = img.transpose((2, 0, 1)) - img = torch.from_numpy(img).float() - if img.max() > 1: - img /= 255 + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).float() + if img.max() > 1: + img /= 255 - return img + return img - @check_requirements(["torch"]) - def torch_to_numpy(img): - import torch - img = img.numpy() - if img.max() > 1: - img /= 255 - return img.transpose((1, 2, 0)) +@check_requirements(["torch"]) +def torch_to_numpy(img): + import torch + + img = img.numpy() + if img.max() > 1: + img /= 255 + return img.transpose((1, 2, 0)) +@check_requirements(["torch"]) def is_torch_cuda_available(): if is_available("torch"): import torch diff --git a/tests/test_detectron2.py b/tests/test_detectron2.py index ce48a53a0..13230fae5 100644 --- a/tests/test_detectron2.py +++ b/tests/test_detectron2.py @@ -6,7 +6,7 @@ from sahi.model import Detectron2DetectionModel from sahi.utils.cv import read_image from sahi.utils.detectron2 import Detectron2TestConstants -from sahi.utils.import_utils import _torch_version +from sahi.utils.import_utils import get_package_info MODEL_DEVICE = "cpu" CONFIDENCE_THRESHOLD = 0.5 @@ -14,7 +14,7 @@ # note that detectron2 binaries are available only for linux -if _torch_version == "1.10.2": +if get_package_info("torch", verbose=False)[1] == "1.10.2": class TestDetectron2DetectionModel(unittest.TestCase): def test_load_model(self):