Skip to content

Commit

Permalink
Add image data type conversion in image saving
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Dec 12, 2019
1 parent 65ec148 commit 6c9e243
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
13 changes: 8 additions & 5 deletions datumaro/datumaro/util/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def load_image(path):

if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2:
import cv2
image = cv2.imread(path)
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
image = image.astype(np.float32)
elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL:
from PIL import Image
Expand All @@ -39,13 +39,15 @@ def load_image(path):
else:
raise NotImplementedError()

assert len(image.shape) == 3
assert image.shape[2] in [1, 3, 4]
assert len(image.shape) in [2, 3]
if len(image.shape) == 3:
assert image.shape[2] in [3, 4]
return image

def save_image(path, image, params=None):
if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2:
import cv2
image = image.astype(np.uint8)
cv2.imwrite(path, image, params=params)
elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL:
from PIL import Image
Expand Down Expand Up @@ -109,8 +111,9 @@ def decode_image(image_bytes):
else:
raise NotImplementedError()

assert len(image.shape) == 3
assert image.shape[2] in [1, 3, 4]
assert len(image.shape) in [2, 3]
if len(image.shape) == 3:
assert image.shape[2] in [3, 4]
return image


Expand Down
14 changes: 10 additions & 4 deletions datumaro/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ def tearDown(self):

def test_save_and_load_backends(self):
backends = image_module._IMAGE_BACKENDS
for save_backend, load_backend in product(backends, backends):
for save_backend, load_backend, c in product(backends, backends, [1, 3]):
with TestDir() as test_dir:
src_image = np.random.randint(0, 255 + 1, (2, 4, 3))
if c == 1:
src_image = np.random.randint(0, 255 + 1, (2, 4))
else:
src_image = np.random.randint(0, 255 + 1, (2, 4, c))
path = osp.join(test_dir.path, 'img.png') # lossless

image_module._IMAGE_BACKEND = save_backend
Expand All @@ -33,8 +36,11 @@ def test_save_and_load_backends(self):

def test_encode_and_decode_backends(self):
backends = image_module._IMAGE_BACKENDS
for save_backend, load_backend in product(backends, backends):
src_image = np.random.randint(0, 255 + 1, (2, 4, 3))
for save_backend, load_backend, c in product(backends, backends, [1, 3]):
if c == 1:
src_image = np.random.randint(0, 255 + 1, (2, 4))
else:
src_image = np.random.randint(0, 255 + 1, (2, 4, c))

image_module._IMAGE_BACKEND = save_backend
buffer = image_module.encode_image(src_image, '.png') # lossless
Expand Down

0 comments on commit 6c9e243

Please sign in to comment.