Skip to content

Commit

Permalink
refactor torch utils
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon committed Jun 19, 2022
1 parent d1d8f3c commit 0f122f0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
50 changes: 25 additions & 25 deletions sahi/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,37 @@ def empty_cuda_cache():
raise RuntimeError("CUDA not available.")


if is_available("torch"):

@check_requirements(["torch"])
def to_float_tensor(img):
"""
Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W).
Args:
img: np.ndarray
Returns:
torch.tensor
"""
import torch
@check_requirements(["torch"])
def to_float_tensor(img):
"""
Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W).
Args:
img: np.ndarray
Returns:
torch.tensor
"""
import torch

img = img.transpose((2, 0, 1))
img = torch.from_numpy(img).float()
if img.max() > 1:
img /= 255
img = img.transpose((2, 0, 1))
img = torch.from_numpy(img).float()
if img.max() > 1:
img /= 255

return img
return img

@check_requirements(["torch"])
def torch_to_numpy(img):
import torch

img = img.numpy()
if img.max() > 1:
img /= 255
return img.transpose((1, 2, 0))
@check_requirements(["torch"])
def torch_to_numpy(img):
import torch

img = img.numpy()
if img.max() > 1:
img /= 255
return img.transpose((1, 2, 0))


@check_requirements(["torch"])
def is_torch_cuda_available():
if is_available("torch"):
import torch
Expand Down
4 changes: 2 additions & 2 deletions tests/test_detectron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from sahi.model import Detectron2DetectionModel
from sahi.utils.cv import read_image
from sahi.utils.detectron2 import Detectron2TestConstants
from sahi.utils.import_utils import _torch_version
from sahi.utils.import_utils import get_package_info

MODEL_DEVICE = "cpu"
CONFIDENCE_THRESHOLD = 0.5
IMAGE_SIZE = 320

# note that detectron2 binaries are available only for linux

if _torch_version == "1.10.2":
if get_package_info("torch", verbose=False)[1] == "1.10.2":

class TestDetectron2DetectionModel(unittest.TestCase):
def test_load_model(self):
Expand Down

0 comments on commit 0f122f0

Please sign in to comment.