diff --git a/datumaro/datumaro/util/image.py b/datumaro/datumaro/util/image.py index 784a1218772..02364de90ea 100644 --- a/datumaro/datumaro/util/image.py +++ b/datumaro/datumaro/util/image.py @@ -28,7 +28,7 @@ def load_image(path): if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2: import cv2 - image = cv2.imread(path) + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) image = image.astype(np.float32) elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL: from PIL import Image @@ -39,13 +39,15 @@ def load_image(path): else: raise NotImplementedError() - assert len(image.shape) == 3 - assert image.shape[2] in [1, 3, 4] + assert len(image.shape) in [2, 3] + if len(image.shape) == 3: + assert image.shape[2] in [3, 4] return image def save_image(path, image, params=None): if _IMAGE_BACKEND == _IMAGE_BACKENDS.cv2: import cv2 + image = image.astype(np.uint8) cv2.imwrite(path, image, params=params) elif _IMAGE_BACKEND == _IMAGE_BACKENDS.PIL: from PIL import Image @@ -109,8 +111,9 @@ def decode_image(image_bytes): else: raise NotImplementedError() - assert len(image.shape) == 3 - assert image.shape[2] in [1, 3, 4] + assert len(image.shape) in [2, 3] + if len(image.shape) == 3: + assert image.shape[2] in [3, 4] return image diff --git a/datumaro/tests/test_image.py b/datumaro/tests/test_image.py index 143d6c4e41e..424fd9c88dd 100644 --- a/datumaro/tests/test_image.py +++ b/datumaro/tests/test_image.py @@ -17,9 +17,12 @@ def tearDown(self): def test_save_and_load_backends(self): backends = image_module._IMAGE_BACKENDS - for save_backend, load_backend in product(backends, backends): + for save_backend, load_backend, c in product(backends, backends, [1, 3]): with TestDir() as test_dir: - src_image = np.random.randint(0, 255 + 1, (2, 4, 3)) + if c == 1: + src_image = np.random.randint(0, 255 + 1, (2, 4)) + else: + src_image = np.random.randint(0, 255 + 1, (2, 4, c)) path = osp.join(test_dir.path, 'img.png') # lossless image_module._IMAGE_BACKEND = save_backend @@ -33,8 +36,11 @@ def test_save_and_load_backends(self): def test_encode_and_decode_backends(self): backends = image_module._IMAGE_BACKENDS - for save_backend, load_backend in product(backends, backends): - src_image = np.random.randint(0, 255 + 1, (2, 4, 3)) + for save_backend, load_backend, c in product(backends, backends, [1, 3]): + if c == 1: + src_image = np.random.randint(0, 255 + 1, (2, 4)) + else: + src_image = np.random.randint(0, 255 + 1, (2, 4, c)) image_module._IMAGE_BACKEND = save_backend buffer = image_module.encode_image(src_image, '.png') # lossless